Compare commits
7 Commits
jaeone/fe-
...
pysssss/mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
230dfcd32c | ||
|
|
7e2826bf7a | ||
|
|
307e8a0d04 | ||
|
|
e2c9550664 | ||
|
|
a87e107b01 | ||
|
|
c9f1cc42ad | ||
|
|
6cc1f20bd4 |
|
Before Width: | Height: | Size: 6.5 KiB After Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 6.6 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 1.2 KiB |
@@ -1,23 +1,23 @@
|
||||
{
|
||||
"name": "Comfy",
|
||||
"short_name": "Comfy",
|
||||
"id": "/",
|
||||
"start_url": "/",
|
||||
"icons": [
|
||||
{
|
||||
"src": "/web-app-manifest-192x192.png",
|
||||
"sizes": "192x192",
|
||||
"type": "image/png",
|
||||
"purpose": "any maskable"
|
||||
"purpose": "any"
|
||||
},
|
||||
{
|
||||
"src": "/web-app-manifest-512x512.png",
|
||||
"sizes": "512x512",
|
||||
"type": "image/png",
|
||||
"purpose": "any maskable"
|
||||
"purpose": "any"
|
||||
}
|
||||
],
|
||||
"theme_color": "#211927",
|
||||
"background_color": "#211927",
|
||||
"display": "standalone",
|
||||
"id": "/",
|
||||
"start_url": "/"
|
||||
"display": "standalone"
|
||||
}
|
||||
|
||||
|
Before Width: | Height: | Size: 3.9 KiB After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 8.8 KiB After Width: | Height: | Size: 38 KiB |
@@ -8,7 +8,6 @@ import {
|
||||
useDownloadUrl
|
||||
} from '../../../composables/useDownloadUrl'
|
||||
import { t } from '../../../i18n/translations'
|
||||
import { captureDownloadClick } from '../../../scripts/posthog'
|
||||
import BrandButton from '../../common/BrandButton.vue'
|
||||
|
||||
const { locale = 'en', class: customClass = '' } = defineProps<{
|
||||
@@ -70,7 +69,6 @@ const buttons = computed<ButtonSpec[]>(() => {
|
||||
size="lg"
|
||||
:class="customClass"
|
||||
:aria-label="btn.ariaLabel"
|
||||
@click="captureDownloadClick(btn.key)"
|
||||
>
|
||||
<span class="inline-flex items-center gap-2">
|
||||
<img
|
||||
|
||||
@@ -73,7 +73,7 @@ const websiteJsonLd = {
|
||||
|
||||
<link rel="icon" type="image/png" href="/favicon-96x96.png" sizes="96x96" />
|
||||
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||
<link rel="shortcut icon" href="/favicon.ico" sizes="48x48" />
|
||||
<link rel="shortcut icon" href="/favicon.ico" />
|
||||
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
|
||||
<link rel="manifest" href="/site.webmanifest" />
|
||||
<meta name="theme-color" content="#211927" />
|
||||
|
||||
@@ -53,28 +53,3 @@ describe('initPostHog', () => {
|
||||
expect(result.$set_once).toHaveProperty('plan', 'free')
|
||||
})
|
||||
})
|
||||
|
||||
describe('captureDownloadClick', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
it('captures the download event with the platform', async () => {
|
||||
const { initPostHog, captureDownloadClick } = await import('./posthog')
|
||||
initPostHog()
|
||||
captureDownloadClick('mac')
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
'website:download_button_clicked',
|
||||
{ platform: 'mac' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does not capture before PostHog is initialized', async () => {
|
||||
const { captureDownloadClick } = await import('./posthog')
|
||||
captureDownloadClick('windows')
|
||||
|
||||
expect(hoisted.mockCapture).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -38,12 +38,3 @@ export function capturePageview() {
|
||||
console.error('PostHog pageview capture failed', error)
|
||||
}
|
||||
}
|
||||
|
||||
export function captureDownloadClick(platform: string) {
|
||||
if (!initialized) return
|
||||
try {
|
||||
posthog.capture('website:download_button_clicked', { platform })
|
||||
} catch (error) {
|
||||
console.error('PostHog download click capture failed', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
{
|
||||
"last_node_id": 2,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "E2E_OldSampler",
|
||||
"pos": [100, 100],
|
||||
"size": [400, 262],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{ "name": "model", "type": "MODEL", "link": null },
|
||||
{ "name": "positive", "type": "CONDITIONING", "link": null },
|
||||
{ "name": "negative", "type": "CONDITIONING", "link": null },
|
||||
{ "name": "latent_image", "type": "LATENT", "link": null }
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "LATENT",
|
||||
"type": "LATENT",
|
||||
"links": [],
|
||||
"slot_index": 0
|
||||
}
|
||||
],
|
||||
"properties": { "Node name for S&R": "E2E_OldSampler" },
|
||||
"widgets_values": [42, 20, 7, "euler", "normal"]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "E2E_OldSampler",
|
||||
"pos": [520, 100],
|
||||
"size": [400, 262],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{ "name": "model", "type": "MODEL", "link": null },
|
||||
{ "name": "positive", "type": "CONDITIONING", "link": null },
|
||||
{ "name": "negative", "type": "CONDITIONING", "link": null },
|
||||
{ "name": "latent_image", "type": "LATENT", "link": null }
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "LATENT",
|
||||
"type": "LATENT",
|
||||
"links": [],
|
||||
"slot_index": 0
|
||||
}
|
||||
],
|
||||
"properties": { "Node name for S&R": "E2E_OldSampler" },
|
||||
"widgets_values": [43, 20, 7, "euler", "normal"]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": { "ds": { "scale": 1, "offset": [0, 0] } },
|
||||
"version": 0.4
|
||||
}
|
||||
@@ -61,14 +61,15 @@ export const TestIds = {
|
||||
missingModelsGroup: 'error-group-missing-model',
|
||||
missingModelExpand: 'missing-model-expand',
|
||||
missingModelLocate: 'missing-model-locate',
|
||||
missingModelReferenceCount: 'missing-model-reference-count',
|
||||
missingModelCopyName: 'missing-model-copy-name',
|
||||
missingModelCopyUrl: 'missing-model-copy-url',
|
||||
missingModelDownload: 'missing-model-download',
|
||||
missingModelActions: 'missing-model-actions',
|
||||
missingModelDownloadAll: 'missing-model-download-all',
|
||||
missingModelRefresh: 'missing-model-header-refresh',
|
||||
missingModelRefresh: 'missing-model-refresh',
|
||||
missingModelImportUnsupported: 'missing-model-import-unsupported',
|
||||
missingMediaGroup: 'error-group-missing-media',
|
||||
swapNodesGroup: 'error-group-swap-nodes',
|
||||
swapNodeGroupCount: 'swap-node-group-count',
|
||||
missingMediaRow: 'missing-media-row',
|
||||
missingMediaLocateButton: 'missing-media-locate-button',
|
||||
publishTabPanel: 'publish-tab-panel',
|
||||
|
||||
61
browser_tests/tests/browseModelAssets.spec.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import type { Asset } from '@comfyorg/ingest-types'
|
||||
import { createCloudAssetsFixture } from '@e2e/fixtures/assetApiFixture'
|
||||
import { STABLE_CHECKPOINT } from '@e2e/fixtures/data/assetFixtures'
|
||||
|
||||
const CLOUD_ASSETS: Asset[] = [STABLE_CHECKPOINT]
|
||||
|
||||
const test = createCloudAssetsFixture(CLOUD_ASSETS)
|
||||
|
||||
test.describe('Browse Model Assets - Use button', { tag: '@cloud' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting('Comfy.Assets.UseAssetAPI', true)
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test('Use button ghost-places a loader populated with the model', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await comfyPage.command.executeCommand('Comfy.BrowseModelAssets')
|
||||
|
||||
const modal = comfyPage.page.locator(
|
||||
'[data-component-id="AssetBrowserModal"]'
|
||||
)
|
||||
await expect(modal).toBeVisible()
|
||||
|
||||
const card = comfyPage.page.locator(
|
||||
`[data-component-id="AssetCard"][data-asset-id="${STABLE_CHECKPOINT.id}"]`
|
||||
)
|
||||
await expect(card).toBeVisible()
|
||||
await card.getByRole('button', { name: 'Use' }).click()
|
||||
|
||||
// Dialog closes and the ghost is armed; the node is not placed until the
|
||||
// user clicks the canvas.
|
||||
await expect(modal).toBeHidden()
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe(STABLE_CHECKPOINT.name)
|
||||
})
|
||||
})
|
||||
@@ -48,36 +48,6 @@ test.describe('Node replacement', { tag: ['@node', '@ui'] }, () => {
|
||||
).toBeVisible()
|
||||
})
|
||||
|
||||
test('Shows direct row label and locate action for a single replacement group', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const swapGroup = getSwapNodesGroup(comfyPage.page)
|
||||
const rowLabel = swapGroup.getByRole('button', {
|
||||
name: 'E2E_OldSampler',
|
||||
exact: true
|
||||
})
|
||||
|
||||
await expect(rowLabel).toBeVisible()
|
||||
await expect(
|
||||
swapGroup.getByRole('button', {
|
||||
name: 'Locate node on canvas',
|
||||
exact: true
|
||||
})
|
||||
).toBeVisible()
|
||||
await expect(
|
||||
swapGroup.getByTestId(TestIds.dialogs.swapNodeGroupCount)
|
||||
).toHaveCount(0)
|
||||
|
||||
await comfyPage.canvasOps.pan({ x: -800, y: -800 })
|
||||
const offsetBeforeLocate = await comfyPage.canvasOps.getOffset()
|
||||
|
||||
await rowLabel.click()
|
||||
|
||||
await expect
|
||||
.poll(() => comfyPage.canvasOps.getOffset())
|
||||
.not.toEqual(offsetBeforeLocate)
|
||||
})
|
||||
|
||||
test('Replace Node replaces a single group in-place', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
@@ -146,55 +116,6 @@ test.describe('Node replacement', { tag: ['@node', '@ui'] }, () => {
|
||||
})
|
||||
})
|
||||
|
||||
test.describe('Same-type replacement group', () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting(
|
||||
'Comfy.VueNodes.Enabled',
|
||||
mode.vueNodesEnabled
|
||||
)
|
||||
await setupNodeReplacement(comfyPage, mockNodeReplacementsSingle)
|
||||
await loadWorkflowAndOpenErrorsTab(
|
||||
comfyPage,
|
||||
'missing/node_replacement_same_type'
|
||||
)
|
||||
})
|
||||
|
||||
test('Groups same-type replacement rows behind the title disclosure', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const swapGroup = getSwapNodesGroup(comfyPage.page)
|
||||
const countBadge = swapGroup.getByTestId(
|
||||
TestIds.dialogs.swapNodeGroupCount
|
||||
)
|
||||
const childRows = swapGroup.getByRole('listitem')
|
||||
const expandButton = swapGroup.getByRole('button', {
|
||||
name: 'Expand E2E_OldSampler',
|
||||
exact: true
|
||||
})
|
||||
|
||||
await expect(expandButton).toBeVisible()
|
||||
await expect(countBadge).toHaveText('2')
|
||||
await expect(childRows).toHaveCount(0)
|
||||
|
||||
await expandButton.click()
|
||||
await expect(childRows).toHaveCount(2)
|
||||
await expect(
|
||||
swapGroup.getByRole('button', {
|
||||
name: 'E2E_OldSampler',
|
||||
exact: true
|
||||
})
|
||||
).toHaveCount(2)
|
||||
|
||||
await swapGroup
|
||||
.getByRole('button', {
|
||||
name: 'Collapse E2E_OldSampler',
|
||||
exact: true
|
||||
})
|
||||
.click()
|
||||
await expect(childRows).toHaveCount(0)
|
||||
})
|
||||
})
|
||||
|
||||
test.describe('Multi-type replacement', () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { expect } from '@playwright/test'
|
||||
import type { Locator } from '@playwright/test'
|
||||
|
||||
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
@@ -12,18 +11,6 @@ import {
|
||||
loadWorkflowAndOpenErrorsTab
|
||||
} from '@e2e/fixtures/helpers/ErrorsTabHelper'
|
||||
|
||||
const FAKE_MODEL_NAME = 'fake_model.safetensors'
|
||||
|
||||
function getModelLabel(group: Locator, modelName: string = FAKE_MODEL_NAME) {
|
||||
return group.getByRole('button', { name: modelName, exact: true })
|
||||
}
|
||||
|
||||
async function expectReferenceBadge(group: Locator, count: number) {
|
||||
await expect(
|
||||
group.getByTestId(TestIds.dialogs.missingModelReferenceCount)
|
||||
).toHaveText(String(count))
|
||||
}
|
||||
|
||||
test.describe('Errors tab - Missing models', { tag: '@ui' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting(
|
||||
@@ -47,14 +34,15 @@ test.describe('Errors tab - Missing models', { tag: '@ui' }, () => {
|
||||
).toHaveText(/\S/)
|
||||
})
|
||||
|
||||
test('Should display model name and metadata', async ({ comfyPage }) => {
|
||||
test('Should display model name with referencing node count', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await loadWorkflowAndOpenErrorsTab(comfyPage, 'missing/missing_models')
|
||||
|
||||
const modelsGroup = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
)
|
||||
await expect(getModelLabel(modelsGroup)).toBeVisible()
|
||||
await expect(modelsGroup.getByText('checkpoints')).toBeVisible()
|
||||
await expect(modelsGroup).toContainText(/fake_model\.safetensors\s*\(\d+\)/)
|
||||
})
|
||||
|
||||
test('Should expand model row to show referencing nodes', async ({
|
||||
@@ -65,33 +53,32 @@ test.describe('Errors tab - Missing models', { tag: '@ui' }, () => {
|
||||
'missing/missing_models_with_nodes'
|
||||
)
|
||||
|
||||
const modelsGroup = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
const locateButton = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelLocate
|
||||
)
|
||||
const expandButton = modelsGroup.getByTestId(
|
||||
await expect(locateButton.first()).toBeHidden()
|
||||
|
||||
const expandButton = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelExpand
|
||||
)
|
||||
await expect(expandButton.first()).toBeVisible()
|
||||
await expectReferenceBadge(modelsGroup, 2)
|
||||
await expandButton.first().click()
|
||||
|
||||
await expect(
|
||||
modelsGroup.getByTestId(TestIds.dialogs.missingModelLocate)
|
||||
).toHaveCount(2)
|
||||
await expect(locateButton.first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Should copy model URL to clipboard', async ({ comfyPage }) => {
|
||||
test('Should copy model name to clipboard', async ({ comfyPage }) => {
|
||||
await loadWorkflowAndOpenErrorsTab(comfyPage, 'missing/missing_models')
|
||||
await interceptClipboardWrite(comfyPage.page)
|
||||
|
||||
const copyButton = comfyPage.page.getByRole('button', {
|
||||
name: 'Copy URL'
|
||||
})
|
||||
const copyButton = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelCopyName
|
||||
)
|
||||
await expect(copyButton.first()).toBeVisible()
|
||||
await copyButton.first().dispatchEvent('click')
|
||||
|
||||
const copiedText = await getClipboardText(comfyPage.page)
|
||||
expect(copiedText).toContain('/api/devtools/')
|
||||
expect(copiedText).toContain('fake_model.safetensors')
|
||||
})
|
||||
|
||||
test.describe('OSS-specific', { tag: '@oss' }, () => {
|
||||
@@ -100,9 +87,9 @@ test.describe('Errors tab - Missing models', { tag: '@ui' }, () => {
|
||||
}) => {
|
||||
await loadWorkflowAndOpenErrorsTab(comfyPage, 'missing/missing_models')
|
||||
|
||||
const copyUrlButton = comfyPage.page.getByRole('button', {
|
||||
name: 'Copy URL'
|
||||
})
|
||||
const copyUrlButton = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelCopyUrl
|
||||
)
|
||||
await expect(copyUrlButton.first()).toBeVisible()
|
||||
})
|
||||
|
||||
@@ -115,7 +102,6 @@ test.describe('Errors tab - Missing models', { tag: '@ui' }, () => {
|
||||
TestIds.dialogs.missingModelDownload
|
||||
)
|
||||
await expect(downloadButton.first()).toBeVisible()
|
||||
await expect(downloadButton.first()).toHaveText('Download')
|
||||
})
|
||||
|
||||
test('Should render Download all and Refresh actions for one downloadable model', async ({
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { expect } from '@playwright/test'
|
||||
import type { Locator } from '@playwright/test'
|
||||
|
||||
import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage'
|
||||
import { TestIds } from '@e2e/fixtures/selectors'
|
||||
@@ -9,18 +8,6 @@ import {
|
||||
loadWorkflowAndOpenErrorsTab
|
||||
} from '@e2e/fixtures/helpers/ErrorsTabHelper'
|
||||
|
||||
const FAKE_MODEL_NAME = 'fake_model.safetensors'
|
||||
|
||||
function getModelLabel(group: Locator, modelName: string = FAKE_MODEL_NAME) {
|
||||
return group.getByRole('button', { name: modelName, exact: true })
|
||||
}
|
||||
|
||||
async function expectReferenceBadge(group: Locator, count: number) {
|
||||
await expect(
|
||||
group.getByTestId(TestIds.dialogs.missingModelReferenceCount)
|
||||
).toHaveText(String(count))
|
||||
}
|
||||
|
||||
test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting('Comfy.UseNewMenu', 'Top')
|
||||
@@ -143,9 +130,9 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
'missing/missing_models_from_node_properties'
|
||||
)
|
||||
|
||||
const copyUrlButton = comfyPage.page.getByRole('button', {
|
||||
name: 'Copy URL'
|
||||
})
|
||||
const copyUrlButton = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelCopyUrl
|
||||
)
|
||||
await expect(copyUrlButton.first()).toBeVisible()
|
||||
|
||||
const node = await comfyPage.nodeOps.getNodeRefById('1')
|
||||
@@ -169,7 +156,9 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
)
|
||||
await expect(missingModelGroup).toBeVisible()
|
||||
await expect(getModelLabel(missingModelGroup)).toBeVisible()
|
||||
await expect(missingModelGroup).toContainText(
|
||||
/fake_model\.safetensors\s*\(1\)/
|
||||
)
|
||||
|
||||
const node = await comfyPage.nodeOps.getNodeRefById('1')
|
||||
await node.click('title')
|
||||
@@ -179,7 +168,9 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
await expect.poll(() => comfyPage.nodeOps.getNodeCount()).toBe(2)
|
||||
|
||||
await comfyPage.canvas.click()
|
||||
await expectReferenceBadge(missingModelGroup, 2)
|
||||
await expect(missingModelGroup).toContainText(
|
||||
/fake_model\.safetensors\s*\(2\)/
|
||||
)
|
||||
})
|
||||
|
||||
test('Pasting a bypassed node does not add a new error', async ({
|
||||
@@ -261,17 +252,14 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
const missingModelGroup = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
)
|
||||
await expectReferenceBadge(missingModelGroup, 2)
|
||||
await expect(missingModelGroup).toContainText(/\(2\)/)
|
||||
|
||||
const node1 = await comfyPage.nodeOps.getNodeRefById('1')
|
||||
await node1.click('title')
|
||||
await expect(getModelLabel(missingModelGroup)).toBeVisible()
|
||||
await expect(
|
||||
missingModelGroup.getByTestId(TestIds.dialogs.missingModelLocate)
|
||||
).toHaveCount(1)
|
||||
await expect(missingModelGroup).toContainText(/\(1\)/)
|
||||
|
||||
await comfyPage.canvas.click()
|
||||
await expectReferenceBadge(missingModelGroup, 2)
|
||||
await expect(missingModelGroup).toContainText(/\(2\)/)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -396,7 +384,9 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
const missingModelGroup = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
)
|
||||
await expect(getModelLabel(missingModelGroup)).toBeVisible()
|
||||
await expect(missingModelGroup).toContainText(
|
||||
/fake_model\.safetensors\s*\(1\)/
|
||||
)
|
||||
|
||||
await comfyPage.page.evaluate((value) => {
|
||||
const hostNode = window.app!.graph!.getNodeById(2)
|
||||
@@ -449,7 +439,9 @@ test.describe('Errors tab - Mode-aware errors', { tag: '@ui' }, () => {
|
||||
const missingModelGroup = comfyPage.page.getByTestId(
|
||||
TestIds.dialogs.missingModelsGroup
|
||||
)
|
||||
await expect(getModelLabel(missingModelGroup)).toBeVisible()
|
||||
await expect(missingModelGroup).toContainText(
|
||||
/fake_model\.safetensors\s*\(1\)/
|
||||
)
|
||||
|
||||
const promotedModelCombo = comfyPage.vueNodes
|
||||
.getNodeByTitle('Subgraph with Promoted Missing Model')
|
||||
|
||||
@@ -233,4 +233,64 @@ test.describe('Model library sidebar - empty state', () => {
|
||||
await expect(tab.folderNodes).toHaveCount(0)
|
||||
await expect(tab.leafNodes).toHaveCount(0)
|
||||
})
|
||||
|
||||
test.describe('Model library sidebar - add node', () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.mockFoldersWithFiles(MOCK_FOLDERS)
|
||||
await comfyPage.setup()
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.clearMocks()
|
||||
})
|
||||
|
||||
test('Clicking a model defers creation until placed on the canvas', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
|
||||
test('Ghost preview shows the model in the loader widget before placing', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
const ghost = comfyPage.page.locator(
|
||||
'[data-node-id="preview-CheckpointLoaderSimple"]'
|
||||
)
|
||||
await expect(ghost).toContainText('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
import type { WorkspaceWithRole } from '@/platform/workspace/api/workspaceApi'
|
||||
import type { WorkspaceTokenResponse } from '@/platform/workspace/stores/workspaceAuthStore'
|
||||
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
|
||||
|
||||
const PERSONAL_WORKSPACE_NAME = 'Personal Workspace'
|
||||
const LONG_WORKSPACE_NAME =
|
||||
'Quantum Renaissance Collective for Hyperdimensional Latent Diffusion Research and Experimental Workflow Engineering'
|
||||
|
||||
// text-sm rows render a single 20px line; a wrapped name is 40px+.
|
||||
const SINGLE_LINE_MAX_HEIGHT_PX = 28
|
||||
|
||||
const mockRemoteConfig: RemoteConfig = { team_workspaces_enabled: true }
|
||||
|
||||
const mockListWorkspacesResponse: { workspaces: WorkspaceWithRole[] } = {
|
||||
workspaces: [
|
||||
{
|
||||
id: 'ws-personal',
|
||||
name: PERSONAL_WORKSPACE_NAME,
|
||||
type: 'personal',
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
joined_at: '2026-01-01T00:00:00Z',
|
||||
role: 'owner'
|
||||
},
|
||||
{
|
||||
id: 'ws-team-long',
|
||||
name: LONG_WORKSPACE_NAME,
|
||||
type: 'team',
|
||||
created_at: '2026-01-02T00:00:00Z',
|
||||
joined_at: '2026-01-02T00:00:00Z',
|
||||
role: 'member'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
const mockTokenResponse: WorkspaceTokenResponse = {
|
||||
token: 'mock-workspace-token',
|
||||
expires_at: new Date(Date.now() + 60 * 60 * 1000).toISOString(),
|
||||
workspace: {
|
||||
id: 'ws-personal',
|
||||
name: PERSONAL_WORKSPACE_NAME,
|
||||
type: 'personal'
|
||||
},
|
||||
role: 'owner',
|
||||
permissions: []
|
||||
}
|
||||
|
||||
const test = comfyPageFixture.extend({
|
||||
page: async ({ page }, use) => {
|
||||
await page.route('**/api/features', (route) =>
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockRemoteConfig)
|
||||
})
|
||||
)
|
||||
|
||||
await page.route('**/api/workspaces', async (route) => {
|
||||
if (route.request().method() !== 'GET') {
|
||||
await route.fallback()
|
||||
return
|
||||
}
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockListWorkspacesResponse)
|
||||
})
|
||||
})
|
||||
|
||||
await page.route('**/api/auth/token', (route) =>
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockTokenResponse)
|
||||
})
|
||||
)
|
||||
|
||||
await use(page)
|
||||
}
|
||||
})
|
||||
|
||||
test.describe('Workspace switcher', { tag: '@cloud' }, () => {
|
||||
test('renders a long team workspace name on a single line', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const page = comfyPage.page
|
||||
|
||||
await comfyPage.toast.closeToasts()
|
||||
await page.getByRole('button', { name: 'Current user' }).click()
|
||||
await page.getByText(PERSONAL_WORKSPACE_NAME).click()
|
||||
|
||||
const longName = page.getByText(LONG_WORKSPACE_NAME)
|
||||
await expect(longName).toBeVisible()
|
||||
|
||||
const box = await longName.boundingBox()
|
||||
expect(box).not.toBeNull()
|
||||
expect(box!.height).toBeLessThan(SINGLE_LINE_MAX_HEIGHT_PX)
|
||||
})
|
||||
})
|
||||
@@ -111,7 +111,6 @@ describe('formatUtil', () => {
|
||||
expect(getMediaTypeFromFilename('scene.fbx')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('asset.gltf')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('binary.glb')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('print.stl')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('apple.usdz')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('scan.ply')).toBe('3D')
|
||||
})
|
||||
|
||||
@@ -591,15 +591,7 @@ const IMAGE_EXTENSIONS = [
|
||||
] as const
|
||||
const VIDEO_EXTENSIONS = ['mp4', 'm4v', 'webm', 'mov', 'avi', 'mkv'] as const
|
||||
const AUDIO_EXTENSIONS = ['mp3', 'wav', 'ogg', 'flac'] as const
|
||||
const THREE_D_EXTENSIONS = [
|
||||
'obj',
|
||||
'fbx',
|
||||
'gltf',
|
||||
'glb',
|
||||
'stl',
|
||||
'usdz',
|
||||
'ply'
|
||||
] as const
|
||||
const THREE_D_EXTENSIONS = ['obj', 'fbx', 'gltf', 'glb', 'usdz', 'ply'] as const
|
||||
const TEXT_EXTENSIONS = [
|
||||
'txt',
|
||||
'md',
|
||||
|
||||
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
@@ -355,7 +355,7 @@ describe('TreeExplorerV2Node', () => {
|
||||
const nodeDiv = getTreeNode(container)
|
||||
await fireEvent.dragStart(nodeDiv)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, 'native')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, { mode: 'native' })
|
||||
})
|
||||
|
||||
it('does not call startDrag for folder items on dragstart', async () => {
|
||||
|
||||
@@ -33,6 +33,7 @@ const {
|
||||
items,
|
||||
gridStyle,
|
||||
bufferRows = 1,
|
||||
scrollThrottle = 64,
|
||||
resizeDebounce = 64,
|
||||
defaultItemHeight = 200,
|
||||
defaultItemWidth = 200,
|
||||
@@ -41,6 +42,7 @@ const {
|
||||
items: (T & { key: string })[]
|
||||
gridStyle: CSSProperties
|
||||
bufferRows?: number
|
||||
scrollThrottle?: number
|
||||
resizeDebounce?: number
|
||||
defaultItemHeight?: number
|
||||
defaultItemWidth?: number
|
||||
@@ -59,6 +61,7 @@ const itemWidth = ref(defaultItemWidth)
|
||||
const container = ref<HTMLElement | null>(null)
|
||||
const { width, height } = useElementSize(container)
|
||||
const { y: scrollY } = useScroll(container, {
|
||||
throttle: scrollThrottle,
|
||||
eventListenerOptions: { passive: true }
|
||||
})
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@
|
||||
|
||||
<NodeTooltip v-if="tooltipEnabled" />
|
||||
<NodeSearchboxPopover ref="nodeSearchboxPopoverRef" />
|
||||
<NodeDragPreview />
|
||||
<VueNodeSwitchPopup />
|
||||
|
||||
<!-- Initialize components after comfyApp is ready. useAbsolutePosition requires
|
||||
@@ -136,6 +137,7 @@ import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue'
|
||||
import LinkOverlayCanvas from '@/components/graph/LinkOverlayCanvas.vue'
|
||||
import NodeTooltip from '@/components/graph/NodeTooltip.vue'
|
||||
import NodeContextMenu from '@/components/graph/NodeContextMenu.vue'
|
||||
import NodeDragPreview from '@/components/graph/NodeDragPreview.vue'
|
||||
import SelectionToolbox from '@/components/graph/SelectionToolbox.vue'
|
||||
import TitleEditor from '@/components/graph/TitleEditor.vue'
|
||||
import NodePropertiesPanel from '@/components/rightSidePanel/RightSidePanel.vue'
|
||||
|
||||
97
src/components/graph/NodeDragPreview.test.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
import { render } from '@testing-library/vue'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import NodeDragPreview from '@/components/graph/NodeDragPreview.vue'
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue',
|
||||
() => ({
|
||||
default: { template: '<div data-testid="node-preview" />' }
|
||||
})
|
||||
)
|
||||
|
||||
const nodeDef = fromPartial<ComfyNodeDefImpl>({ name: 'TestNode' })
|
||||
|
||||
function moveMouse(clientX: number, clientY: number) {
|
||||
window.dispatchEvent(new MouseEvent('mousemove', { clientX, clientY }))
|
||||
}
|
||||
|
||||
function ghostElement() {
|
||||
return document.querySelector('[data-testid="node-preview"]')?.parentElement
|
||||
?.parentElement
|
||||
}
|
||||
|
||||
describe('NodeDragPreview', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
useNodeDragToCanvas().cancelDrag()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('shows no ghost when nothing is being dragged', async () => {
|
||||
render(NodeDragPreview)
|
||||
|
||||
moveMouse(100, 200)
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()).toBeFalsy()
|
||||
})
|
||||
|
||||
it('keeps the ghost hidden until the mouse position is known', async () => {
|
||||
render(NodeDragPreview)
|
||||
|
||||
useNodeDragToCanvas().startDrag(nodeDef)
|
||||
await nextTick()
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()).toBeFalsy()
|
||||
})
|
||||
|
||||
it('follows the mouse with an offset while dragging', async () => {
|
||||
render(NodeDragPreview)
|
||||
|
||||
useNodeDragToCanvas().startDrag(nodeDef)
|
||||
await nextTick()
|
||||
moveMouse(100, 200)
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()?.style.transform).toBe('translate(112px, 212px)')
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()?.style.transform).toBe('translate(112px, 212px)')
|
||||
|
||||
moveMouse(300, 400)
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()?.style.transform).toBe('translate(312px, 412px)')
|
||||
})
|
||||
|
||||
it('removes the ghost when the drag is cancelled', async () => {
|
||||
render(NodeDragPreview)
|
||||
|
||||
useNodeDragToCanvas().startDrag(nodeDef)
|
||||
await nextTick()
|
||||
moveMouse(100, 200)
|
||||
vi.advanceTimersByTime(16)
|
||||
await nextTick()
|
||||
expect(ghostElement()).toBeTruthy()
|
||||
|
||||
useNodeDragToCanvas().cancelDrag()
|
||||
await nextTick()
|
||||
|
||||
expect(ghostElement()).toBeFalsy()
|
||||
})
|
||||
})
|
||||
57
src/components/graph/NodeDragPreview.vue
Normal file
@@ -0,0 +1,57 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="showGhost && rafPosition"
|
||||
class="pointer-events-none fixed top-0 left-0 z-10000 will-change-transform"
|
||||
:style="{
|
||||
transform: `translate(${rafPosition.x + 12}px, ${rafPosition.y + 12}px)`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview
|
||||
:node-def="draggedNode!"
|
||||
:widget-values="pendingWidgetValues"
|
||||
position="relative"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useMouse, useRafFn } from '@vueuse/core'
|
||||
import { computed, shallowRef, watch } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const { isDragging, draggedNode, pendingWidgetValues } = useNodeDragToCanvas()
|
||||
|
||||
const { x, y, sourceType } = useMouse({ type: 'client' })
|
||||
|
||||
const showGhost = computed(() => Boolean(isDragging.value && draggedNode.value))
|
||||
const rafPosition = shallowRef<{ x: number; y: number }>()
|
||||
|
||||
const { pause, resume } = useRafFn(
|
||||
() => {
|
||||
if (sourceType.value === null) return
|
||||
const pos = rafPosition.value
|
||||
if (pos && pos.x === x.value && pos.y === y.value) return
|
||||
rafPosition.value = { x: x.value, y: y.value }
|
||||
},
|
||||
{ immediate: false }
|
||||
)
|
||||
|
||||
watch(
|
||||
showGhost,
|
||||
(show) => {
|
||||
if (show) {
|
||||
resume()
|
||||
} else {
|
||||
pause()
|
||||
rafPosition.value = undefined
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
</script>
|
||||
@@ -23,8 +23,6 @@
|
||||
:can-use-gizmo="canUseGizmo"
|
||||
:can-use-lighting="canUseLighting"
|
||||
:can-export="canExport"
|
||||
:can-use-hdri="canUseHdri"
|
||||
:can-use-background-image="canUseBackgroundImage"
|
||||
:material-modes="materialModes"
|
||||
:has-skeleton="hasSkeleton"
|
||||
@update-background-image="handleBackgroundImageUpdate"
|
||||
@@ -88,7 +86,7 @@
|
||||
/>
|
||||
|
||||
<RecordingControls
|
||||
v-if="canUseRecording && !isPreview"
|
||||
v-if="!isPreview"
|
||||
v-model:is-recording="isRecording"
|
||||
v-model:has-recording="hasRecording"
|
||||
v-model:recording-duration="recordingDuration"
|
||||
@@ -119,18 +117,9 @@ import { resolveNode } from '@/utils/litegraphUtil'
|
||||
import type { ComponentWidget } from '@/scripts/domWidget'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
const {
|
||||
widget,
|
||||
nodeId,
|
||||
canUseRecording = true,
|
||||
canUseHdri = true,
|
||||
canUseBackgroundImage = true
|
||||
} = defineProps<{
|
||||
const props = defineProps<{
|
||||
widget: ComponentWidget<string[]> | SimplifiedWidget
|
||||
nodeId?: NodeId
|
||||
canUseRecording?: boolean
|
||||
canUseHdri?: boolean
|
||||
canUseBackgroundImage?: boolean
|
||||
}>()
|
||||
|
||||
function isComponentWidget(
|
||||
@@ -141,11 +130,11 @@ function isComponentWidget(
|
||||
|
||||
const node = ref<LGraphNode | null>(null)
|
||||
|
||||
if (isComponentWidget(widget)) {
|
||||
node.value = widget.node
|
||||
} else if (nodeId) {
|
||||
if (isComponentWidget(props.widget)) {
|
||||
node.value = props.widget.node
|
||||
} else if (props.nodeId) {
|
||||
onMounted(() => {
|
||||
node.value = resolveNode(nodeId) ?? null
|
||||
node.value = resolveNode(props.nodeId!) ?? null
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import { render } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, h, ref } from 'vue'
|
||||
|
||||
const lastProps = ref<Record<string, unknown> | null>(null)
|
||||
|
||||
vi.mock('@/components/load3d/Load3D.vue', () => ({
|
||||
default: defineComponent({
|
||||
name: 'Load3D',
|
||||
props: {
|
||||
widget: { type: null, required: false, default: undefined },
|
||||
nodeId: { type: null, required: false, default: undefined },
|
||||
canUseRecording: { type: Boolean, default: true },
|
||||
canUseHdri: { type: Boolean, default: true },
|
||||
canUseBackgroundImage: { type: Boolean, default: true }
|
||||
},
|
||||
setup(props: Record<string, unknown>) {
|
||||
lastProps.value = { ...props }
|
||||
return () => h('div', { 'data-testid': 'load3d-stub' })
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
import Load3DAdvanced from '@/components/load3d/Load3DAdvanced.vue'
|
||||
|
||||
describe('Load3DAdvanced', () => {
|
||||
it('renders the inner Load3D with all expressive features disabled', () => {
|
||||
const MOCK_NODE = { id: 'node', type: 'Load3DAdvanced' }
|
||||
render(Load3DAdvanced, {
|
||||
props: {
|
||||
widget: { node: MOCK_NODE } as never
|
||||
}
|
||||
})
|
||||
expect(lastProps.value).toMatchObject({
|
||||
canUseRecording: false,
|
||||
canUseHdri: false,
|
||||
canUseBackgroundImage: false
|
||||
})
|
||||
})
|
||||
|
||||
it('forwards widget and nodeId to the inner Load3D', () => {
|
||||
const widget = { node: { id: 'a', type: 'Load3DAdvanced' } }
|
||||
render(Load3DAdvanced, { props: { widget: widget as never, nodeId: 'a' } })
|
||||
expect(lastProps.value?.widget).toEqual(widget)
|
||||
expect(lastProps.value?.nodeId).toBe('a')
|
||||
})
|
||||
})
|
||||
@@ -1,21 +0,0 @@
|
||||
<template>
|
||||
<Load3D
|
||||
:widget="widget"
|
||||
:node-id="nodeId"
|
||||
:can-use-recording="false"
|
||||
:can-use-hdri="false"
|
||||
:can-use-background-image="false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import Load3D from '@/components/load3d/Load3D.vue'
|
||||
import type { NodeId } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ComponentWidget } from '@/scripts/domWidget'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
defineProps<{
|
||||
widget: ComponentWidget<string[]> | SimplifiedWidget
|
||||
nodeId?: NodeId
|
||||
}>()
|
||||
</script>
|
||||
@@ -52,7 +52,6 @@
|
||||
v-model:background-image="sceneConfig!.backgroundImage"
|
||||
v-model:background-render-mode="sceneConfig!.backgroundRenderMode"
|
||||
v-model:fov="cameraConfig!.fov"
|
||||
:show-background-image="canUseBackgroundImage"
|
||||
:hdri-active="
|
||||
!!lightConfig?.hdri?.hdriPath && !!lightConfig?.hdri?.enabled
|
||||
"
|
||||
@@ -82,7 +81,6 @@
|
||||
/>
|
||||
|
||||
<HDRIControls
|
||||
v-if="canUseHdri"
|
||||
v-model:hdri-config="lightConfig!.hdri"
|
||||
:has-background-image="!!sceneConfig?.backgroundImage"
|
||||
@update-hdri-file="handleHDRIFileUpdate"
|
||||
@@ -131,16 +129,12 @@ const {
|
||||
canUseGizmo = true,
|
||||
canUseLighting = true,
|
||||
canExport = true,
|
||||
canUseHdri = true,
|
||||
canUseBackgroundImage = true,
|
||||
materialModes = ['original', 'normal', 'wireframe'],
|
||||
hasSkeleton = false
|
||||
} = defineProps<{
|
||||
canUseGizmo?: boolean
|
||||
canUseLighting?: boolean
|
||||
canExport?: boolean
|
||||
canUseHdri?: boolean
|
||||
canUseBackgroundImage?: boolean
|
||||
materialModes?: readonly MaterialMode[]
|
||||
hasSkeleton?: boolean
|
||||
}>()
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div v-if="showBackgroundImage && !hasBackgroundImage">
|
||||
<div v-if="!hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.uploadBackgroundImage'),
|
||||
@@ -61,7 +61,7 @@
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div v-if="showBackgroundImage && hasBackgroundImage">
|
||||
<div v-if="hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.panoramaMode'),
|
||||
@@ -83,16 +83,12 @@
|
||||
</div>
|
||||
|
||||
<PopupSlider
|
||||
v-if="
|
||||
showBackgroundImage &&
|
||||
hasBackgroundImage &&
|
||||
backgroundRenderMode === 'panorama'
|
||||
"
|
||||
v-if="hasBackgroundImage && backgroundRenderMode === 'panorama'"
|
||||
v-model="fov"
|
||||
:tooltip-text="$t('load3d.fov')"
|
||||
/>
|
||||
|
||||
<div v-if="showBackgroundImage && hasBackgroundImage">
|
||||
<div v-if="hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.removeBackgroundImage'),
|
||||
@@ -118,9 +114,8 @@ import Button from '@/components/ui/button/Button.vue'
|
||||
import type { BackgroundRenderModeType } from '@/extensions/core/load3d/interfaces'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
const { hdriActive = false, showBackgroundImage = true } = defineProps<{
|
||||
const { hdriActive = false } = defineProps<{
|
||||
hdriActive?: boolean
|
||||
showBackgroundImage?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
|
||||
@@ -530,16 +530,14 @@ describe('TabErrors.vue', () => {
|
||||
expect(
|
||||
screen.getByText('Some nodes can be replaced with alternatives')
|
||||
).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByRole('button', { name: 'OldSampler' })
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByText('OldSampler (1)')).toBeInTheDocument()
|
||||
expect(screen.getByText('KSampler')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByRole('button', { name: /Replace Node/ })
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders missing model Refresh in the header and Download all in the card when models are downloadable', () => {
|
||||
it('keeps missing model Refresh in the card actions when models are downloadable', () => {
|
||||
const missingModel = {
|
||||
nodeId: '1',
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
@@ -557,8 +555,11 @@ describe('TabErrors.vue', () => {
|
||||
}
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('missing-model-header-refresh')).toBeVisible()
|
||||
expect(
|
||||
screen.queryByTestId('missing-model-header-refresh')
|
||||
).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('missing-model-actions')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /Download all/ })).toBeVisible()
|
||||
expect(screen.getByRole('button', { name: 'Refresh' })).toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -94,10 +94,9 @@
|
||||
showMissingModelHeaderRefresh
|
||||
"
|
||||
data-testid="missing-model-header-refresh"
|
||||
variant="muted-textonly"
|
||||
size="icon"
|
||||
class="mr-2 shrink-0 rounded-lg hover:bg-transparent hover:text-base-foreground"
|
||||
:aria-label="t('rightSidePanel.missingModels.refresh')"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="mr-2 h-8 shrink-0 rounded-lg text-sm"
|
||||
:aria-busy="missingModelStore.isRefreshingMissingModels"
|
||||
:aria-disabled="missingModelStore.isRefreshingMissingModels"
|
||||
@click.stop="handleMissingModelRefresh"
|
||||
@@ -113,6 +112,7 @@
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--refresh-cw] size-4 shrink-0"
|
||||
/>
|
||||
{{ t('rightSidePanel.missingModels.refresh') }}
|
||||
</Button>
|
||||
<span
|
||||
v-if="
|
||||
@@ -157,6 +157,7 @@
|
||||
<SwapNodesCard
|
||||
v-if="group.type === 'swap_nodes'"
|
||||
:swap-node-groups="swapNodeGroups"
|
||||
:show-node-id-badge="showNodeIdBadge"
|
||||
@locate-node="handleLocateMissingNode"
|
||||
@replace="handleReplaceGroup"
|
||||
/>
|
||||
@@ -246,6 +247,7 @@
|
||||
<MissingModelCard
|
||||
v-if="group.type === 'missing_model'"
|
||||
:missing-model-groups="missingModelGroups"
|
||||
:show-node-id-badge="showNodeIdBadge"
|
||||
@locate-model="handleLocateAssetNode"
|
||||
/>
|
||||
|
||||
@@ -300,9 +302,11 @@ import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import { useCopyToClipboard } from '@/composables/useCopyToClipboard'
|
||||
import { useFocusNode } from '@/composables/canvas/useFocusNode'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore'
|
||||
import { useManagerState } from '@/workbench/extensions/manager/composables/useManagerState'
|
||||
import { ManagerTab } from '@/workbench/extensions/manager/types/comfyManagerTypes'
|
||||
import { NodeBadgeMode } from '@/types/nodeSource'
|
||||
|
||||
import PropertiesAccordionItem from '../layout/PropertiesAccordionItem.vue'
|
||||
import CollapseToggleButton from '../layout/CollapseToggleButton.vue'
|
||||
@@ -316,6 +320,7 @@ import MissingMediaCard from '@/platform/missingMedia/components/MissingMediaCar
|
||||
import { isCloud, isDesktop, isNightly } from '@/platform/distribution/types'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import DotSpinner from '@/components/common/DotSpinner.vue'
|
||||
import { getDownloadableModels } from '@/platform/missingModel/missingModelViewUtils'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { usePackInstall } from '@/workbench/extensions/manager/composables/nodePack/usePackInstall'
|
||||
import { useMissingNodes } from '@/workbench/extensions/manager/composables/nodePack/useMissingNodes'
|
||||
@@ -343,6 +348,7 @@ const { t } = useI18n()
|
||||
const { copyToClipboard } = useCopyToClipboard()
|
||||
const { focusNode, enterSubgraph } = useFocusNode()
|
||||
const { openGitHubIssues, contactSupport } = useErrorActions()
|
||||
const settingStore = useSettingStore()
|
||||
const rightSidePanelStore = useRightSidePanelStore()
|
||||
const missingModelStore = useMissingModelStore()
|
||||
const { shouldShowManagerButtons, shouldShowInstallButton, openManager } =
|
||||
@@ -366,6 +372,12 @@ function getGroupSize(group: ErrorGroup) {
|
||||
return fullSizeGroupTypes.has(group.type) ? 'lg' : 'default'
|
||||
}
|
||||
|
||||
const showNodeIdBadge = computed(
|
||||
() =>
|
||||
(settingStore.get('Comfy.NodeBadge.NodeIdBadgeMode') as NodeBadgeMode) !==
|
||||
NodeBadgeMode.None
|
||||
)
|
||||
|
||||
function isExecutionItemListGroup(group: ErrorGroup) {
|
||||
return (
|
||||
group.type === 'execution' &&
|
||||
@@ -452,8 +464,17 @@ const {
|
||||
swapNodeGroups
|
||||
} = useErrorGroups(searchQuery)
|
||||
|
||||
const missingModelDownloadableModels = computed(() => {
|
||||
if (isCloud) return []
|
||||
|
||||
return getDownloadableModels(missingModelGroups.value)
|
||||
})
|
||||
|
||||
const showMissingModelHeaderRefresh = computed(
|
||||
() => !isCloud && missingModelGroups.value.length > 0
|
||||
() =>
|
||||
!isCloud &&
|
||||
missingModelGroups.value.length > 0 &&
|
||||
missingModelDownloadableModels.value.length === 0
|
||||
)
|
||||
|
||||
function handleMissingModelRefresh() {
|
||||
|
||||
@@ -14,7 +14,7 @@ const {
|
||||
captureRoot,
|
||||
getRoot,
|
||||
resetRoot,
|
||||
mockAddNodeOnGraph,
|
||||
mockStartDrag,
|
||||
mockGetNodeProvider,
|
||||
mockToggleNodeOnEvent,
|
||||
mockRefreshModelFolder,
|
||||
@@ -29,7 +29,7 @@ const {
|
||||
resetRoot: () => {
|
||||
capturedRoot = null
|
||||
},
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockStartDrag: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn(),
|
||||
mockToggleNodeOnEvent: vi.fn(),
|
||||
mockRefreshModelFolder: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -37,8 +37,8 @@ const {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ addNodeOnGraph: mockAddNodeOnGraph })
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
@@ -173,16 +173,13 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
expect(screen.getByTestId('search-input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles model click and adds node to graph', async () => {
|
||||
it('starts a ghost drag carrying the widget value to fill on placement', async () => {
|
||||
const mockNodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
const mockWidget = { name: 'ckpt_name', value: '' }
|
||||
const mockGraphNode = { widgets: [mockWidget] }
|
||||
|
||||
mockGetNodeProvider.mockReturnValue({
|
||||
nodeDef: mockNodeDef,
|
||||
key: 'ckpt_name'
|
||||
})
|
||||
mockAddNodeOnGraph.mockReturnValue(mockGraphNode)
|
||||
|
||||
renderComponent()
|
||||
await nextTick()
|
||||
@@ -198,8 +195,9 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
await modelLeaf?.handleClick?.(mockEvent)
|
||||
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef)
|
||||
expect(mockWidget.value).toBe('model.safetensors')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
})
|
||||
|
||||
it('toggles folder expansion on click', async () => {
|
||||
|
||||
@@ -63,10 +63,9 @@ import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue
|
||||
import ElectronDownloadItems from '@/components/sidebar/tabs/modelLibrary/ElectronDownloadItems.vue'
|
||||
import ModelTreeLeaf from '@/components/sidebar/tabs/modelLibrary/ModelTreeLeaf.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { startModelLoaderDrag } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useTreeExpansion } from '@/composables/useTreeExpansion'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
||||
import type { ComfyModelDef, ModelFolder } from '@/stores/modelStore'
|
||||
import { ResourceState, useModelStore } from '@/stores/modelStore'
|
||||
@@ -156,15 +155,7 @@ const renderedRoot = computed<TreeExplorerNode<ModelOrFolder>>(() => {
|
||||
if (this.leaf && model) {
|
||||
const provider = modelToNodeStore.getNodeProvider(model.directory)
|
||||
if (provider) {
|
||||
const graphNode = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(provider.nodeDef)
|
||||
)
|
||||
const widget = graphNode?.widgets?.find(
|
||||
(widget) => widget.name === provider.key
|
||||
)
|
||||
if (widget) {
|
||||
widget.value = model.file_name
|
||||
}
|
||||
startModelLoaderDrag(provider, model.file_name)
|
||||
}
|
||||
} else {
|
||||
toggleNodeOnEvent(e, node)
|
||||
|
||||
@@ -31,11 +31,8 @@ vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({
|
||||
isDragging: { value: false },
|
||||
draggedNode: { value: null },
|
||||
cursorPosition: { value: { x: 0, y: 0 } },
|
||||
startDrag: vi.fn(),
|
||||
cancelDrag: vi.fn(),
|
||||
setupGlobalListeners: vi.fn(),
|
||||
cleanupGlobalListeners: vi.fn()
|
||||
cancelDrag: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
|
||||
@@ -115,7 +115,6 @@
|
||||
</div>
|
||||
</template>
|
||||
<template #body>
|
||||
<NodeDragPreview />
|
||||
<div class="flex h-full flex-col">
|
||||
<div
|
||||
v-if="hasNoMatches"
|
||||
@@ -215,7 +214,6 @@ import type {
|
||||
import AllNodesPanel from './nodeLibrary/AllNodesPanel.vue'
|
||||
import BlueprintsPanel from './nodeLibrary/BlueprintsPanel.vue'
|
||||
import EssentialNodesPanel from './nodeLibrary/EssentialNodesPanel.vue'
|
||||
import NodeDragPreview from './nodeLibrary/NodeDragPreview.vue'
|
||||
import SidebarTabTemplate from './SidebarTabTemplate.vue'
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="isDragging && draggedNode && showPreview"
|
||||
class="pointer-events-none fixed z-10000"
|
||||
:style="{
|
||||
left: `${previewPosition.x + 12}px`,
|
||||
top: `${previewPosition.y + 12}px`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview :node-def="draggedNode" position="relative" />
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, onUnmounted, ref } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
} = useNodeDragToCanvas()
|
||||
|
||||
const nativeDragPosition = ref({ x: 0, y: 0 })
|
||||
|
||||
const previewPosition = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value
|
||||
}
|
||||
return cursorPosition.value
|
||||
})
|
||||
|
||||
const showPreview = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value.x > 0 || nativeDragPosition.value.y > 0
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
function handleDrag(e: DragEvent) {
|
||||
if (e.clientX === 0 && e.clientY === 0) return
|
||||
nativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function handleDragEnd() {
|
||||
nativeDragPosition.value = { x: 0, y: 0 }
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
setupGlobalListeners()
|
||||
document.addEventListener('drag', handleDrag)
|
||||
document.addEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
cleanupGlobalListeners()
|
||||
document.removeEventListener('drag', handleDrag)
|
||||
document.removeEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
</script>
|
||||
73
src/composables/node/startModelNodeDragFromAsset.test.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
|
||||
const { mockStartDrag, mockGetNodeProvider } = vi.hoisted(() => ({
|
||||
mockStartDrag: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'sd_xl_base_1.0.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: { filename: 'sd_xl_base_1.0.safetensors' },
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('startModelNodeDragFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
it('starts a ghost drag for the resolved node carrying the widget value', () => {
|
||||
const nodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: 'ckpt_name' })
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error).toBeUndefined()
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
})
|
||||
|
||||
it('carries no widget value when the provider has no key', () => {
|
||||
const nodeDef = { name: 'FL_ChatterboxVC' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: '' })
|
||||
|
||||
startModelNodeDragFromAsset(
|
||||
createAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('returns the resolution error and does not start a drag for an invalid asset', () => {
|
||||
mockGetNodeProvider.mockReturnValue(null)
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error?.code).toBe('NO_PROVIDER')
|
||||
expect(mockStartDrag).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
35
src/composables/node/startModelNodeDragFromAsset.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ResolveModelNodeError } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
/**
|
||||
* Arms a ghost drag for a model loader node. Providers with no widget key
|
||||
* (auto-load nodes) start the drag without widget values.
|
||||
*/
|
||||
export function startModelLoaderDrag(
|
||||
provider: ModelNodeProvider,
|
||||
filename: string
|
||||
) {
|
||||
const widgetValues = provider.key ? { [provider.key]: filename } : undefined
|
||||
useNodeDragToCanvas().startDrag(provider.nodeDef, { widgetValues })
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a ghost drag for the model loader node described by an asset. The
|
||||
* node is created where the user next clicks the canvas, with the asset's
|
||||
* filename written into the loader widget.
|
||||
*
|
||||
* @returns the resolution error when the asset cannot be mapped to a node,
|
||||
* otherwise `undefined`.
|
||||
*/
|
||||
export function startModelNodeDragFromAsset(
|
||||
asset: AssetItem
|
||||
): ResolveModelNodeError | undefined {
|
||||
const resolved = resolveModelNodeFromAsset(asset)
|
||||
if (!resolved.success) return resolved.error
|
||||
|
||||
const { provider, filename } = resolved.value
|
||||
startModelLoaderDrag(provider, filename)
|
||||
}
|
||||
@@ -7,7 +7,8 @@ const {
|
||||
mockAddNodeOnGraph,
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockCanvas
|
||||
mockCanvas,
|
||||
mockToastAdd
|
||||
} = vi.hoisted(() => {
|
||||
const mockConvertEventToCanvasOffset = vi.fn()
|
||||
const mockSelectItems = vi.fn()
|
||||
@@ -15,6 +16,7 @@ const {
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockToastAdd: vi.fn(),
|
||||
mockCanvas: {
|
||||
canvas: {
|
||||
getBoundingClientRect: vi.fn()
|
||||
@@ -37,6 +39,12 @@ vi.mock('@/services/litegraphService', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => ({ add: mockToastAdd }))
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({ t: (key: string) => key }))
|
||||
|
||||
describe('useNodeDragToCanvas', () => {
|
||||
let useNodeDragToCanvas: typeof UseNodeDragToCanvasType
|
||||
|
||||
@@ -54,8 +62,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
const { cleanupGlobalListeners } = useNodeDragToCanvas()
|
||||
cleanupGlobalListeners()
|
||||
const { cancelDrag } = useNodeDragToCanvas()
|
||||
cancelDrag()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
@@ -71,22 +79,6 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(true)
|
||||
expect(draggedNode.value).toBe(mockNodeDef)
|
||||
})
|
||||
|
||||
it('should set dragMode to click by default', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
|
||||
it('should set dragMode to native when specified', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
expect(dragMode.value).toBe('native')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cancelDrag', () => {
|
||||
@@ -102,30 +94,15 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(draggedNode.value).toBeNull()
|
||||
})
|
||||
|
||||
it('should reset dragMode to click', () => {
|
||||
const { dragMode, startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
expect(dragMode.value).toBe('native')
|
||||
|
||||
cancelDrag()
|
||||
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setupGlobalListeners', () => {
|
||||
it('should add event listeners to document', () => {
|
||||
describe('drag listener lifecycle', () => {
|
||||
it('should attach document listeners on startDrag', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointermove',
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
@@ -142,35 +119,53 @@ describe('useNodeDragToCanvas', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('should only setup listeners once', () => {
|
||||
it('should not attach drag listeners until a drag starts', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'keydown',
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should detach document listeners on cancelDrag', () => {
|
||||
const removeEventListenerSpy = vi.spyOn(document, 'removeEventListener')
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('should only attach listeners once across re-arms', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
const callCount = addEventListenerSpy.mock.calls.length
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(addEventListenerSpy.mock.calls.length).toBe(callCount)
|
||||
})
|
||||
})
|
||||
|
||||
describe('cursorPosition', () => {
|
||||
it('should update on pointermove', () => {
|
||||
const { cursorPosition, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
|
||||
const pointerEvent = new PointerEvent('pointermove', {
|
||||
clientX: 100,
|
||||
clientY: 200
|
||||
})
|
||||
document.dispatchEvent(pointerEvent)
|
||||
|
||||
expect(cursorPosition.value).toEqual({ x: 100, y: 200 })
|
||||
})
|
||||
})
|
||||
|
||||
describe('endDrag behavior', () => {
|
||||
it('should add node when pointer is over canvas', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
@@ -181,9 +176,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -206,10 +199,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
bottom: 500
|
||||
})
|
||||
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -224,10 +214,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should cancel drag on Escape key', () => {
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(isDragging.value).toBe(true)
|
||||
@@ -239,10 +226,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should not cancel drag on other keys', () => {
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const keyEvent = new KeyboardEvent('keydown', { key: 'Enter' })
|
||||
@@ -262,8 +246,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
const placedNode = { id: 1 }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -277,6 +260,97 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
})
|
||||
|
||||
it('should apply the requested widget values to the placed node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const widget = { name: 'ckpt_name', value: '' }
|
||||
mockAddNodeOnGraph.mockReturnValue({ id: 1, widgets: [widget] })
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(widget.value).toBe('model.safetensors')
|
||||
})
|
||||
|
||||
it('should still place the node when a requested widget is missing', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const placedNode = { id: 1, widgets: [] }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
expect(mockToastAdd).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('ckpt_name')
|
||||
)
|
||||
})
|
||||
|
||||
it('should show an error toast when the graph fails to add the node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'error',
|
||||
detail: 'assetBrowser.failedToCreateNode'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should not call selectItems when graph returns no node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
@@ -286,9 +360,9 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -311,11 +385,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
@@ -341,7 +412,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef, {
|
||||
@@ -359,7 +430,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(600, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -377,7 +448,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'click')
|
||||
startDrag(mockNodeDef)
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -392,14 +463,12 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([200, 200])
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging, dragMode } =
|
||||
useNodeDragToCanvas()
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -426,31 +495,29 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should stop propagation when in click-drag mode over canvas', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation when not dragging', () => {
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
it('should not stop propagation once the drag is cancelled', () => {
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation in native drag mode', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation when pointer is outside canvas', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(600, 250)).not.toHaveBeenCalled()
|
||||
@@ -477,10 +544,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
}
|
||||
|
||||
it('should prefer tracked drag position over dragend coordinates', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
fireDrag(250, 250)
|
||||
// dragend supplies a bad position (the Firefox bug); the tracked one
|
||||
@@ -494,10 +559,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should ignore drag events with (0, 0)', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
fireDrag(250, 250)
|
||||
fireDrag(0, 0)
|
||||
@@ -510,10 +573,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should fall back to dragend coordinates when no drag fired', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
@@ -523,32 +584,14 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should ignore dragover events fired before startDrag', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
fireDrag(250, 250)
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenCalledWith({
|
||||
clientX: 300,
|
||||
clientY: 300
|
||||
})
|
||||
})
|
||||
|
||||
it('should clear tracked position between drags', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
fireDrag(250, 250)
|
||||
handleNativeDrop(1505, 102)
|
||||
|
||||
// Second drag - no drag events, so we should fall back to args.
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenLastCalledWith({
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
import { ref, shallowRef } from 'vue'
|
||||
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
|
||||
type DragMode = 'click' | 'native'
|
||||
type WidgetValues = Record<string, string>
|
||||
type Position = { x: number; y: number }
|
||||
|
||||
interface StartDragOptions {
|
||||
mode?: DragMode
|
||||
widgetValues?: WidgetValues
|
||||
}
|
||||
|
||||
const isDragging = ref(false)
|
||||
const draggedNode = shallowRef<ComfyNodeDefImpl | null>(null)
|
||||
const cursorPosition = ref({ x: 0, y: 0 })
|
||||
const dragMode = ref<DragMode>('click')
|
||||
const lastNativeDragPosition = shallowRef<{ x: number; y: number }>()
|
||||
const lastNativeDragPosition = shallowRef<Position>()
|
||||
const pendingWidgetValues = shallowRef<WidgetValues>()
|
||||
let listenersSetup = false
|
||||
|
||||
function updatePosition(e: PointerEvent) {
|
||||
cursorPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
// Firefox dragend can report stale clientX/Y and `drag` can fire with
|
||||
// (0, 0). dragover on the target reliably reports real client coords.
|
||||
// https://bugzilla.mozilla.org/show_bug.cgi?id=1773886
|
||||
@@ -27,11 +33,15 @@ function trackNativeDragPosition(e: DragEvent) {
|
||||
lastNativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
function applyWidgetValues(node: LGraphNode, values: WidgetValues) {
|
||||
for (const [name, value] of Object.entries(values)) {
|
||||
const widget = node.widgets?.find((w) => w.name === name)
|
||||
if (!widget) {
|
||||
console.error(`Widget ${name} not found on node ${node.type}`)
|
||||
continue
|
||||
}
|
||||
widget.value = value
|
||||
}
|
||||
}
|
||||
|
||||
function isOverCanvas(clientX: number, clientY: number): boolean {
|
||||
@@ -62,7 +72,19 @@ function addNodeAtPosition(clientX: number, clientY: number): boolean {
|
||||
const node = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(nodeDef, { pos })
|
||||
)
|
||||
if (node) canvas.selectItems([node])
|
||||
if (!node) {
|
||||
console.error(`Failed to add node to graph: ${nodeDef.name}`)
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
if (pendingWidgetValues.value)
|
||||
applyWidgetValues(node, pendingWidgetValues.value)
|
||||
canvas.selectItems([node])
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -92,7 +114,6 @@ function setupGlobalListeners() {
|
||||
if (listenersSetup) return
|
||||
listenersSetup = true
|
||||
|
||||
document.addEventListener('pointermove', updatePosition)
|
||||
document.addEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.addEventListener('pointerup', endDrag, true)
|
||||
document.addEventListener('keydown', handleKeydown)
|
||||
@@ -103,22 +124,31 @@ function cleanupGlobalListeners() {
|
||||
if (!listenersSetup) return
|
||||
listenersSetup = false
|
||||
|
||||
document.removeEventListener('pointermove', updatePosition)
|
||||
document.removeEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.removeEventListener('pointerup', endDrag, true)
|
||||
document.removeEventListener('keydown', handleKeydown)
|
||||
document.removeEventListener('dragover', trackNativeDragPosition)
|
||||
}
|
||||
|
||||
if (isDragging.value && dragMode.value === 'click') {
|
||||
cancelDrag()
|
||||
}
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
pendingWidgetValues.value = undefined
|
||||
cleanupGlobalListeners()
|
||||
}
|
||||
|
||||
export function useNodeDragToCanvas() {
|
||||
function startDrag(nodeDef: ComfyNodeDefImpl, mode: DragMode = 'click') {
|
||||
function startDrag(
|
||||
nodeDef: ComfyNodeDefImpl,
|
||||
{ mode = 'click', widgetValues }: StartDragOptions = {}
|
||||
) {
|
||||
isDragging.value = true
|
||||
draggedNode.value = nodeDef
|
||||
dragMode.value = mode
|
||||
pendingWidgetValues.value = widgetValues
|
||||
setupGlobalListeners()
|
||||
}
|
||||
|
||||
function handleNativeDrop(clientX: number, clientY: number) {
|
||||
@@ -134,12 +164,9 @@ export function useNodeDragToCanvas() {
|
||||
return {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
pendingWidgetValues,
|
||||
startDrag,
|
||||
cancelDrag,
|
||||
handleNativeDrop,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
handleNativeDrop
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +124,9 @@ describe('useNodePreviewAndDrag', () => {
|
||||
|
||||
expect(result.isDragging.value).toBe(true)
|
||||
expect(result.isHovered.value).toBe(false)
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, 'native')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
mode: 'native'
|
||||
})
|
||||
expect(mockDataTransfer.effectAllowed).toBe('copy')
|
||||
expect(mockDataTransfer.setData).toHaveBeenCalledWith(
|
||||
'application/x-comfy-node',
|
||||
|
||||
@@ -112,7 +112,7 @@ export function useNodePreviewAndDrag(
|
||||
isDragging.value = true
|
||||
isHovered.value = false
|
||||
|
||||
startDrag(nodeDef.value, 'native')
|
||||
startDrag(nodeDef.value, { mode: 'native' })
|
||||
|
||||
if (e.dataTransfer) {
|
||||
e.dataTransfer.effectAllowed = 'copy'
|
||||
|
||||
@@ -5,11 +5,13 @@ import { ref } from 'vue'
|
||||
import { useCoreCommands } from '@/composables/useCoreCommands'
|
||||
import { useExternalLink } from '@/composables/useExternalLink'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { api } from '@/scripts/api'
|
||||
import { app } from '@/scripts/app'
|
||||
import type * as ModelStoreModule from '@/stores/modelStore'
|
||||
import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
// Mock vue-i18n for useExternalLink
|
||||
const mockLocale = ref('en')
|
||||
@@ -135,6 +137,23 @@ vi.mock('@/stores/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => ({}))
|
||||
}))
|
||||
|
||||
const mockToastAdd = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => ({ add: mockToastAdd }))
|
||||
}))
|
||||
|
||||
const mockAssetBrowse = vi.hoisted(() =>
|
||||
vi.fn<(options: { onAssetSelected?: (asset: AssetItem) => void }) => void>()
|
||||
)
|
||||
vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => ({
|
||||
useAssetBrowserDialog: vi.fn(() => ({ browse: mockAssetBrowse }))
|
||||
}))
|
||||
|
||||
const mockStartModelNodeDrag = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/composables/node/startModelNodeDragFromAsset', () => ({
|
||||
startModelNodeDragFromAsset: mockStartModelNodeDrag
|
||||
}))
|
||||
|
||||
const mockChangeTracker = vi.hoisted(() => ({
|
||||
captureCanvasState: vi.fn()
|
||||
}))
|
||||
@@ -618,4 +637,44 @@ describe('useCoreCommands', () => {
|
||||
expect(mockShowAbout).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('BrowseModelAssets command', () => {
|
||||
const asset = fromPartial<AssetItem>({ id: 'asset-1' })
|
||||
|
||||
async function selectAssetFromBrowser() {
|
||||
vi.mocked(useSettingStore).mockReturnValue(createMockSettingStore(true))
|
||||
|
||||
const command = useCoreCommands().find(
|
||||
(cmd) => cmd.id === 'Comfy.BrowseModelAssets'
|
||||
)!
|
||||
await command.function()
|
||||
|
||||
const { onAssetSelected } = mockAssetBrowse.mock.calls[0][0]
|
||||
onAssetSelected?.(asset)
|
||||
}
|
||||
|
||||
it('starts a model node drag for the selected asset', async () => {
|
||||
mockStartModelNodeDrag.mockReturnValue(undefined)
|
||||
|
||||
await selectAssetFromBrowser()
|
||||
|
||||
expect(mockStartModelNodeDrag).toHaveBeenCalledWith(asset)
|
||||
expect(mockToastAdd).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows an error toast when the asset cannot start a drag', async () => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockStartModelNodeDrag.mockReturnValue({
|
||||
code: 'NO_PROVIDER',
|
||||
message: 'No node provider registered',
|
||||
assetId: 'asset-1'
|
||||
})
|
||||
|
||||
await selectAssetFromBrowser()
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useCurrentUser } from '@/composables/auth/useCurrentUser'
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems'
|
||||
import { useSubgraphOperations } from '@/composables/graph/useSubgraphOperations'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useExternalLink } from '@/composables/useExternalLink'
|
||||
import { useModelSelectorDialog } from '@/composables/useModelSelectorDialog'
|
||||
import {
|
||||
@@ -20,7 +21,6 @@ import {
|
||||
import type { Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { buildSupportUrl } from '@/platform/support/config'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
@@ -1305,14 +1305,14 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
assetType: 'models',
|
||||
title: t('sideToolbar.modelLibrary'),
|
||||
onAssetSelected: (asset) => {
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
if (!result.success) {
|
||||
const error = startModelNodeDragFromAsset(asset)
|
||||
if (error) {
|
||||
toastStore.add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
console.error('Node creation failed:', result.error)
|
||||
console.error('Node creation failed:', error)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -22,7 +22,6 @@ import {
|
||||
LOAD3D_NONE_MODEL,
|
||||
SUPPORTED_EXTENSIONS_ACCEPT
|
||||
} from '@/extensions/core/load3d/constants'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import Load3dUtils from '@/extensions/core/load3d/Load3dUtils'
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
@@ -414,10 +413,16 @@ useExtensionService().registerExtension({
|
||||
if (cached) return cached
|
||||
}
|
||||
|
||||
const { camera_info, model_3d_info } = snapshotLoad3dState(
|
||||
node,
|
||||
currentLoad3d
|
||||
)
|
||||
const cameraConfig: CameraConfig = (node.properties[
|
||||
'Camera Config'
|
||||
] as CameraConfig | undefined) || {
|
||||
cameraType: currentLoad3d.getCurrentCameraType(),
|
||||
fov: currentLoad3d.cameraManager.perspectiveCamera.fov
|
||||
}
|
||||
cameraConfig.state = currentLoad3d.getCameraState()
|
||||
node.properties['Camera Config'] = cameraConfig
|
||||
|
||||
currentLoad3d.stopRecording()
|
||||
|
||||
const {
|
||||
scene: imageData,
|
||||
@@ -436,11 +441,16 @@ useExtensionService().registerExtension({
|
||||
|
||||
currentLoad3d.handleResize()
|
||||
|
||||
const modelInfo = currentLoad3d.getModelInfo()
|
||||
const model_3d_info: Model3DInfo = modelInfo ? [modelInfo] : []
|
||||
|
||||
const returnVal: Load3dCachedOutput = {
|
||||
image: `threed/${data.name} [temp]`,
|
||||
mask: `threed/${dataMask.name} [temp]`,
|
||||
normal: `threed/${dataNormal.name} [temp]`,
|
||||
camera_info,
|
||||
camera_info:
|
||||
(node.properties['Camera Config'] as CameraConfig | undefined)
|
||||
?.state || null,
|
||||
recording: '',
|
||||
model_3d_info
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ const mtlLoaderStub = {
|
||||
const objLoaderStub = {
|
||||
setWorkerUrl: vi.fn(),
|
||||
setMaterials: vi.fn(),
|
||||
setBaseObject3d: vi.fn(),
|
||||
loadAsync: vi.fn<(url: string) => Promise<THREE.Object3D>>()
|
||||
}
|
||||
|
||||
@@ -59,7 +58,6 @@ vi.mock('wwobjloader2', () => ({
|
||||
OBJLoader2Parallel: class {
|
||||
setWorkerUrl = objLoaderStub.setWorkerUrl
|
||||
setMaterials = objLoaderStub.setMaterials
|
||||
setBaseObject3d = objLoaderStub.setBaseObject3d
|
||||
loadAsync = objLoaderStub.loadAsync
|
||||
},
|
||||
MtlObjBridge: {
|
||||
@@ -249,24 +247,6 @@ describe('MeshModelAdapter', () => {
|
||||
|
||||
expect(ctx.registerOriginalMaterial).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('resets baseObject3d on every load so meshes do not accumulate across calls', async () => {
|
||||
objLoaderStub.loadAsync.mockResolvedValue(makeFbxLikeGroup())
|
||||
|
||||
const adapter = new MeshModelAdapter()
|
||||
const ctx = makeContext('wireframe')
|
||||
await adapter.load(ctx, '/api/view/', 'first.obj')
|
||||
await adapter.load(ctx, '/api/view/', 'second.obj')
|
||||
|
||||
expect(objLoaderStub.setBaseObject3d).toHaveBeenCalledTimes(2)
|
||||
const bases = objLoaderStub.setBaseObject3d.mock.calls.map(
|
||||
([base]) => base
|
||||
)
|
||||
expect(bases[0]).toBeInstanceOf(THREE.Object3D)
|
||||
expect(bases[1]).toBeInstanceOf(THREE.Object3D)
|
||||
// Each call should hand the loader a fresh container, not the same one.
|
||||
expect(bases[0]).not.toBe(bases[1])
|
||||
})
|
||||
})
|
||||
|
||||
describe('GLTF loader path', () => {
|
||||
|
||||
@@ -102,8 +102,6 @@ export class MeshModelAdapter implements ModelAdapter {
|
||||
path: string,
|
||||
filename: string
|
||||
): Promise<THREE.Object3D> {
|
||||
this.objLoader.setBaseObject3d(new THREE.Object3D())
|
||||
|
||||
if (ctx.materialMode === 'original') {
|
||||
try {
|
||||
this.mtlLoader.setPath(path)
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type Load3d from '@/extensions/core/load3d/Load3d'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import type { CameraState } from '@/extensions/core/load3d/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
|
||||
function makeNode(props: Record<string, unknown> = {}): LGraphNode {
|
||||
return { properties: { ...props } } as unknown as LGraphNode
|
||||
}
|
||||
|
||||
const baseCameraState: CameraState = {
|
||||
position: { x: 1, y: 2, z: 3 },
|
||||
target: { x: 0, y: 0, z: 0 },
|
||||
zoom: 1,
|
||||
cameraType: 'perspective'
|
||||
} as unknown as CameraState
|
||||
|
||||
function makeLoad3d({
|
||||
cameraType = 'perspective',
|
||||
fov = 35,
|
||||
modelInfo = { transform: { position: [0, 0, 0] } } as unknown
|
||||
}: {
|
||||
cameraType?: string
|
||||
fov?: number
|
||||
modelInfo?: unknown
|
||||
} = {}) {
|
||||
return {
|
||||
getCurrentCameraType: vi.fn(() => cameraType),
|
||||
cameraManager: { perspectiveCamera: { fov } },
|
||||
getCameraState: vi.fn(() => baseCameraState),
|
||||
stopRecording: vi.fn(),
|
||||
getModelInfo: vi.fn(() => modelInfo)
|
||||
} as unknown as Load3d
|
||||
}
|
||||
|
||||
describe('snapshotLoad3dState', () => {
|
||||
it('returns only camera_info and model_3d_info', () => {
|
||||
const result = snapshotLoad3dState(makeNode(), makeLoad3d())
|
||||
expect(Object.keys(result).sort()).toEqual(['camera_info', 'model_3d_info'])
|
||||
})
|
||||
|
||||
it('writes the camera state into properties["Camera Config"]', () => {
|
||||
const node = makeNode()
|
||||
snapshotLoad3dState(node, makeLoad3d({ fov: 42 }))
|
||||
const cfg = node.properties['Camera Config'] as Record<string, unknown>
|
||||
expect(cfg).toMatchObject({
|
||||
cameraType: 'perspective',
|
||||
fov: 42,
|
||||
state: baseCameraState
|
||||
})
|
||||
})
|
||||
|
||||
it('preserves an existing Camera Config object instead of replacing it', () => {
|
||||
const existing = { cameraType: 'orthographic', fov: 99 }
|
||||
const node = makeNode({ 'Camera Config': existing })
|
||||
snapshotLoad3dState(node, makeLoad3d())
|
||||
// Same object reference (mutated in place), with state attached.
|
||||
expect(node.properties['Camera Config']).toBe(existing)
|
||||
expect(
|
||||
(node.properties['Camera Config'] as Record<string, unknown>).state
|
||||
).toBe(baseCameraState)
|
||||
})
|
||||
|
||||
it('stops in-progress recording as a side effect', () => {
|
||||
const load3d = makeLoad3d()
|
||||
snapshotLoad3dState(makeNode(), load3d)
|
||||
expect(load3d.stopRecording).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('returns model_3d_info as a single-element list when a model is loaded', () => {
|
||||
const info = { transform: { position: [1, 2, 3] } }
|
||||
const result = snapshotLoad3dState(
|
||||
makeNode(),
|
||||
makeLoad3d({ modelInfo: info })
|
||||
)
|
||||
expect(result.model_3d_info).toEqual([info])
|
||||
})
|
||||
|
||||
it('returns an empty model_3d_info list when no model is loaded', () => {
|
||||
const result = snapshotLoad3dState(
|
||||
makeNode(),
|
||||
makeLoad3d({ modelInfo: null })
|
||||
)
|
||||
expect(result.model_3d_info).toEqual([])
|
||||
})
|
||||
})
|
||||
@@ -1,36 +0,0 @@
|
||||
import type Load3d from '@/extensions/core/load3d/Load3d'
|
||||
import type {
|
||||
CameraConfig,
|
||||
CameraState,
|
||||
Model3DInfo
|
||||
} from '@/extensions/core/load3d/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
|
||||
export type Load3dSerializedBase = {
|
||||
camera_info: CameraState | null
|
||||
model_3d_info: Model3DInfo
|
||||
}
|
||||
|
||||
export function snapshotLoad3dState(
|
||||
node: LGraphNode,
|
||||
load3d: Load3d
|
||||
): Load3dSerializedBase {
|
||||
const cameraConfig: CameraConfig = (node.properties['Camera Config'] as
|
||||
| CameraConfig
|
||||
| undefined) || {
|
||||
cameraType: load3d.getCurrentCameraType(),
|
||||
fov: load3d.cameraManager.perspectiveCamera.fov
|
||||
}
|
||||
cameraConfig.state = load3d.getCameraState()
|
||||
node.properties['Camera Config'] = cameraConfig
|
||||
|
||||
load3d.stopRecording()
|
||||
|
||||
const modelInfo = load3d.getModelInfo()
|
||||
const model_3d_info: Model3DInfo = modelInfo ? [modelInfo] : []
|
||||
|
||||
return {
|
||||
camera_info: cameraConfig.state ?? null,
|
||||
model_3d_info
|
||||
}
|
||||
}
|
||||
@@ -9,12 +9,7 @@ const LOAD3D_PREVIEW_NODES = new Set([
|
||||
'PreviewPointCloud'
|
||||
])
|
||||
|
||||
const LOAD3D_ALL_NODES = new Set([
|
||||
...LOAD3D_PREVIEW_NODES,
|
||||
'Load3D',
|
||||
'Load3DAdvanced',
|
||||
'SaveGLB'
|
||||
])
|
||||
const LOAD3D_ALL_NODES = new Set([...LOAD3D_PREVIEW_NODES, 'Load3D', 'SaveGLB'])
|
||||
|
||||
export const isLoad3dPreviewNode = (nodeType: string): boolean =>
|
||||
LOAD3D_PREVIEW_NODES.has(nodeType)
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import Load3DAdvanced from '@/components/load3d/Load3DAdvanced.vue'
|
||||
import { nodeToLoad3dMap, useLoad3d } from '@/composables/useLoad3d'
|
||||
import { createExportMenuItems } from '@/extensions/core/load3d/exportMenuHelper'
|
||||
import type { CameraConfig } from '@/extensions/core/load3d/interfaces'
|
||||
import Load3DConfiguration from '@/extensions/core/load3d/Load3DConfiguration'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import type { IContextMenuValue } from '@/lib/litegraph/src/interfaces'
|
||||
import type { CustomInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import { ComponentWidgetImpl, addWidget } from '@/scripts/domWidget'
|
||||
import { useExtensionService } from '@/services/extensionService'
|
||||
import { useLoad3dService } from '@/services/load3dService'
|
||||
|
||||
const inputSpecLoad3DAdvanced: CustomInputSpec = {
|
||||
name: 'viewport_state',
|
||||
type: 'LOAD_3D_ADVANCED',
|
||||
isPreview: false
|
||||
}
|
||||
|
||||
useExtensionService().registerExtension({
|
||||
name: 'Comfy.Load3DAdvanced',
|
||||
|
||||
beforeRegisterNodeDef(_nodeType, nodeData) {
|
||||
if (nodeData.name !== 'Load3DAdvanced') return
|
||||
if (!nodeData.input?.required) return
|
||||
nodeData.input.required.viewport_state = ['LOAD_3D_ADVANCED', {}]
|
||||
},
|
||||
|
||||
getNodeMenuItems(node: LGraphNode): (IContextMenuValue | null)[] {
|
||||
if (node.constructor.comfyClass !== 'Load3DAdvanced') return []
|
||||
|
||||
const load3d = useLoad3dService().getLoad3d(node)
|
||||
if (!load3d) return []
|
||||
|
||||
return createExportMenuItems(load3d)
|
||||
},
|
||||
|
||||
getCustomWidgets() {
|
||||
return {
|
||||
LOAD_3D_ADVANCED(node) {
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'viewport_state',
|
||||
component: Load3DAdvanced,
|
||||
inputSpec: inputSpecLoad3DAdvanced,
|
||||
options: {}
|
||||
})
|
||||
|
||||
widget.type = 'load3DAdvanced'
|
||||
|
||||
addWidget(node, widget)
|
||||
|
||||
return { widget }
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async nodeCreated(node: LGraphNode) {
|
||||
if (node.constructor.comfyClass !== 'Load3DAdvanced') return
|
||||
|
||||
const [oldWidth, oldHeight] = node.size
|
||||
node.setSize([Math.max(oldWidth, 300), Math.max(oldHeight, 600)])
|
||||
|
||||
await nextTick()
|
||||
|
||||
useLoad3d(node).onLoad3dReady((load3d) => {
|
||||
const modelWidget = node.widgets?.find((w) => w.name === 'model_file')
|
||||
const width = node.widgets?.find((w) => w.name === 'width')
|
||||
const height = node.widgets?.find((w) => w.name === 'height')
|
||||
if (!modelWidget || !width || !height) return
|
||||
|
||||
const cameraConfig = node.properties['Camera Config'] as
|
||||
| CameraConfig
|
||||
| undefined
|
||||
const cameraState = cameraConfig?.state
|
||||
|
||||
const config = new Load3DConfiguration(load3d, node.properties)
|
||||
config.configure({
|
||||
loadFolder: 'input',
|
||||
modelWidget,
|
||||
cameraState,
|
||||
width,
|
||||
height
|
||||
})
|
||||
})
|
||||
|
||||
useLoad3d(node).waitForLoad3d(() => {
|
||||
const sceneWidget = node.widgets?.find((w) => w.name === 'viewport_state')
|
||||
if (!sceneWidget) return
|
||||
|
||||
sceneWidget.serializeValue = async () => {
|
||||
const currentLoad3d = nodeToLoad3dMap.get(node)
|
||||
if (!currentLoad3d) {
|
||||
console.error('No load3d instance found for node')
|
||||
return null
|
||||
}
|
||||
return snapshotLoad3dState(node, currentLoad3d)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -37,7 +37,6 @@ async function loadLoad3dExtensions(): Promise<ComfyExtension[]> {
|
||||
// Import extensions - they self-register via useExtensionService()
|
||||
await Promise.all([
|
||||
import('./load3d'),
|
||||
import('./load3dAdvanced'),
|
||||
import('./load3dPreviewExtensions'),
|
||||
import('./saveMesh')
|
||||
])
|
||||
@@ -67,12 +66,6 @@ useExtensionService().registerExtension({
|
||||
modelFile[1].mesh_upload = true
|
||||
modelFile[1].upload_subfolder = '3d'
|
||||
}
|
||||
} else if (nodeData.name === 'Load3DAdvanced') {
|
||||
const modelFile = nodeData.input?.required?.model_file
|
||||
if (modelFile?.[1]) {
|
||||
modelFile[1].mesh_upload = true
|
||||
modelFile[1].upload_subfolder = ''
|
||||
}
|
||||
}
|
||||
|
||||
// Load the 3D extensions and replay their beforeRegisterNodeDef hooks,
|
||||
|
||||
@@ -3087,13 +3087,6 @@
|
||||
"loadingModels": "Loading {type}...",
|
||||
"maxFileSize": "Max file size: {size}",
|
||||
"maxFileSizeValue": "1 GB",
|
||||
"missingModelImportTypeLocked": "Locked to {type} for this missing model",
|
||||
"missingModelImportTypeMismatchAlreadyImported": "This file is already imported as {actual}.",
|
||||
"missingModelImportTypeMismatchNextAction": "Try importing a different {required} model that this node can use.",
|
||||
"missingModelImportTypeMismatchRequired": "This node requires {required}, so this import cannot resolve the missing model.",
|
||||
"missingModelImportTypeMismatchTitle": "This model cannot resolve the missing model.",
|
||||
"missingModelImportUnknownType": "another model type",
|
||||
"missingModelImportWillReplace": "This import will replace {model} in:",
|
||||
"modelAssociatedWithLink": "The model associated with the link you provided:",
|
||||
"modelName": "Model Name",
|
||||
"modelNamePlaceholder": "Enter a name for this model",
|
||||
@@ -3647,7 +3640,9 @@
|
||||
},
|
||||
"missingModels": {
|
||||
"urlPlaceholder": "Paste Model URL (Civitai or Hugging Face)",
|
||||
"readyToApply": "Ready to apply",
|
||||
"or": "OR",
|
||||
"useFromLibrary": "Use from Library",
|
||||
"usingFromLibrary": "Using from Library",
|
||||
"unsupportedUrl": "Only Civitai and Hugging Face URLs are supported.",
|
||||
"metadataFetchFailed": "Failed to retrieve metadata. Please check the link and try again.",
|
||||
"import": "Import",
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
import type { UploadModelDialogContext } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
|
||||
import UploadModelConfirmation from './UploadModelConfirmation.vue'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false,
|
||||
escapeParameter: true
|
||||
})
|
||||
|
||||
const SingleSelectStub = {
|
||||
name: 'SingleSelect',
|
||||
props: {
|
||||
disabled: Boolean,
|
||||
modelValue: String
|
||||
},
|
||||
template:
|
||||
'<button type="button" :disabled="disabled">{{ modelValue }}</button>'
|
||||
}
|
||||
|
||||
describe('UploadModelConfirmation', () => {
|
||||
it('shows missing-model replacement context and locks the model type', () => {
|
||||
const uploadContext: UploadModelDialogContext = {
|
||||
kind: 'missing-model-resolution',
|
||||
missingModelName: 'segm/person_yolov8m-seg.pt',
|
||||
requiredModelType: 'Ultralytics/bbox',
|
||||
replacementTargets: [
|
||||
{
|
||||
nodeId: '1',
|
||||
nodeLabel: 'Checkpoint Loader',
|
||||
widgetName: 'ckpt_name'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
render(UploadModelConfirmation, {
|
||||
props: {
|
||||
modelValue: 'Ultralytics/bbox',
|
||||
metadata: {
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/123',
|
||||
filename: 'replacement.safetensors'
|
||||
},
|
||||
uploadContext,
|
||||
'onUpdate:modelValue': () => {}
|
||||
},
|
||||
global: {
|
||||
plugins: [i18n],
|
||||
stubs: {
|
||||
SingleSelect: SingleSelectStub
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(screen.getByText('segm/person_yolov8m-seg.pt')).toBeInTheDocument()
|
||||
expect(screen.getByText('Checkpoint Loader')).toBeInTheDocument()
|
||||
expect(screen.getByText('- ckpt_name')).toBeInTheDocument()
|
||||
const modelTypeSelect = screen.getByRole('button', {
|
||||
name: 'Ultralytics/bbox'
|
||||
})
|
||||
|
||||
expect(modelTypeSelect).toBeDisabled()
|
||||
expect(
|
||||
screen.getByText((_content, element) => {
|
||||
return (
|
||||
element?.textContent ===
|
||||
'Locked to Ultralytics/bbox for this missing model'
|
||||
)
|
||||
})
|
||||
).toBeInTheDocument()
|
||||
expect(screen.queryByText(///)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -22,50 +22,16 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="isMissingModelResolution"
|
||||
class="flex flex-col gap-2 rounded-lg bg-secondary-background px-4 py-3"
|
||||
>
|
||||
<i18n-t
|
||||
keypath="assetBrowser.missingModelImportWillReplace"
|
||||
tag="p"
|
||||
class="m-0 text-base-foreground"
|
||||
>
|
||||
<template #model>
|
||||
<span>{{ missingModelName }}</span>
|
||||
</template>
|
||||
</i18n-t>
|
||||
<ul class="m-0 list-none space-y-1 p-0">
|
||||
<li
|
||||
v-for="target in replacementTargets"
|
||||
:key="`${target.nodeId}:${target.widgetName}`"
|
||||
class="flex min-w-0 items-center gap-2"
|
||||
>
|
||||
<span class="min-w-0 truncate text-muted-foreground">
|
||||
{{ target.nodeLabel }}
|
||||
</span>
|
||||
<span class="shrink-0 text-muted-foreground">
|
||||
- {{ target.widgetName }}
|
||||
</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<!-- Model Type Selection -->
|
||||
<div class="flex flex-col gap-2">
|
||||
<div class="flex flex-col gap-1">
|
||||
<div class="flex items-center gap-2">
|
||||
<label>
|
||||
{{ $t('assetBrowser.modelTypeSelectorLabel') }}
|
||||
</label>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--circle-question-mark] text-muted-foreground"
|
||||
/>
|
||||
<span v-if="!isMissingModelResolution" class="text-muted-foreground">
|
||||
{{ $t('assetBrowser.notSureLeaveAsIs') }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<label>
|
||||
{{ $t('assetBrowser.modelTypeSelectorLabel') }}
|
||||
</label>
|
||||
<i class="icon-[lucide--circle-question-mark] text-muted-foreground" />
|
||||
<span class="text-muted-foreground">
|
||||
{{ $t('assetBrowser.notSureLeaveAsIs') }}
|
||||
</span>
|
||||
</div>
|
||||
<SingleSelect
|
||||
v-model="modelValue"
|
||||
@@ -75,37 +41,23 @@
|
||||
: $t('assetBrowser.modelTypeSelectorPlaceholder')
|
||||
"
|
||||
:options="modelTypes"
|
||||
:disabled="isLoading || isMissingModelResolution"
|
||||
:disabled="isLoading"
|
||||
:content-style="selectContentStyle"
|
||||
data-attr="upload-model-step2-type-selector"
|
||||
/>
|
||||
<i18n-t
|
||||
v-if="isMissingModelResolution"
|
||||
keypath="assetBrowser.missingModelImportTypeLocked"
|
||||
tag="span"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<template #type>
|
||||
<span>{{ selectedModelTypeLabel }}</span>
|
||||
</template>
|
||||
</i18n-t>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
|
||||
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
|
||||
import { usePrimeVueOverlayChildStyle } from '@/composables/usePopoverSizing'
|
||||
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
|
||||
import type { UploadModelDialogContext } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import type { AssetMetadata } from '@/platform/assets/schemas/assetSchema'
|
||||
|
||||
const { uploadContext } = defineProps<{
|
||||
defineProps<{
|
||||
metadata?: AssetMetadata
|
||||
previewImage?: string
|
||||
uploadContext?: UploadModelDialogContext
|
||||
}>()
|
||||
|
||||
const modelValue = defineModel<string | undefined>()
|
||||
@@ -113,27 +65,4 @@ const modelValue = defineModel<string | undefined>()
|
||||
const { modelTypes, isLoading } = useModelTypes()
|
||||
const primeVueOverlay = usePrimeVueOverlayChildStyle()
|
||||
const selectContentStyle = primeVueOverlay.contentStyle
|
||||
|
||||
const isMissingModelResolution = computed(
|
||||
() => uploadContext?.kind === 'missing-model-resolution'
|
||||
)
|
||||
const missingModelName = computed(() =>
|
||||
uploadContext?.kind === 'missing-model-resolution'
|
||||
? uploadContext.missingModelName
|
||||
: ''
|
||||
)
|
||||
const replacementTargets = computed(() =>
|
||||
uploadContext?.kind === 'missing-model-resolution'
|
||||
? uploadContext.replacementTargets
|
||||
: []
|
||||
)
|
||||
const selectedModelTypeLabel = computed(() => {
|
||||
const value =
|
||||
uploadContext?.kind === 'missing-model-resolution'
|
||||
? uploadContext.requiredModelType
|
||||
: modelValue.value
|
||||
return (
|
||||
modelTypes.value.find((option) => option.value === value)?.name ?? value
|
||||
)
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
v-model="selectedModelType"
|
||||
:metadata="wizardData.metadata"
|
||||
:preview-image="wizardData.previewImage"
|
||||
:upload-context="uploadContext"
|
||||
/>
|
||||
|
||||
<!-- Step 3: Upload Progress -->
|
||||
@@ -25,7 +24,6 @@
|
||||
v-else-if="currentStep === 3 && uploadStatus != null"
|
||||
:result="uploadStatus"
|
||||
:error="uploadError"
|
||||
:type-mismatch="uploadTypeMismatch"
|
||||
:metadata="wizardData.metadata"
|
||||
:model-type="selectedModelType"
|
||||
:preview-image="wizardData.previewImage"
|
||||
@@ -41,7 +39,6 @@
|
||||
:can-fetch-metadata="canFetchMetadata"
|
||||
:can-upload-model="canUploadModel"
|
||||
:upload-status="uploadStatus"
|
||||
:can-import-another="!isMissingModelResolution"
|
||||
@back="goToPreviousStep"
|
||||
@fetch-metadata="handleFetchMetadata"
|
||||
@upload="handleUploadModel"
|
||||
@@ -52,47 +49,29 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted } from 'vue'
|
||||
import { onMounted } from 'vue'
|
||||
|
||||
import UploadModelConfirmation from '@/platform/assets/components/UploadModelConfirmation.vue'
|
||||
import UploadModelFooter from '@/platform/assets/components/UploadModelFooter.vue'
|
||||
import UploadModelProgress from '@/platform/assets/components/UploadModelProgress.vue'
|
||||
import UploadModelUrlInput from '@/platform/assets/components/UploadModelUrlInput.vue'
|
||||
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
|
||||
import type {
|
||||
UploadModelDialogContext,
|
||||
UploadModelSuccess
|
||||
} from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import { useUploadModelWizard } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
const dialogStore = useDialogStore()
|
||||
const { modelTypes, fetchModelTypes } = useModelTypes()
|
||||
|
||||
const { uploadContext } = defineProps<{
|
||||
uploadContext?: UploadModelDialogContext
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'upload-success': [result: UploadModelSuccess]
|
||||
'upload-success': []
|
||||
}>()
|
||||
|
||||
const isMissingModelResolution = computed(
|
||||
() => uploadContext?.kind === 'missing-model-resolution'
|
||||
)
|
||||
const requiredModelType = computed(() =>
|
||||
uploadContext?.kind === 'missing-model-resolution'
|
||||
? uploadContext.requiredModelType
|
||||
: undefined
|
||||
)
|
||||
|
||||
const {
|
||||
currentStep,
|
||||
isFetchingMetadata,
|
||||
isUploading,
|
||||
uploadStatus,
|
||||
uploadError,
|
||||
uploadTypeMismatch,
|
||||
wizardData,
|
||||
selectedModelType,
|
||||
canFetchMetadata,
|
||||
@@ -101,18 +80,16 @@ const {
|
||||
uploadModel,
|
||||
goToPreviousStep,
|
||||
resetWizard
|
||||
} = useUploadModelWizard(modelTypes, {
|
||||
requiredModelType: requiredModelType.value
|
||||
})
|
||||
} = useUploadModelWizard(modelTypes)
|
||||
|
||||
async function handleFetchMetadata() {
|
||||
await fetchMetadata()
|
||||
}
|
||||
|
||||
async function handleUploadModel() {
|
||||
const result = await uploadModel()
|
||||
if (result) {
|
||||
emit('upload-success', result)
|
||||
const success = await uploadModel()
|
||||
if (success) {
|
||||
emit('upload-success')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import UploadModelFooter from './UploadModelFooter.vue'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
function renderFooter(
|
||||
props: Partial<InstanceType<typeof UploadModelFooter>['$props']> = {}
|
||||
) {
|
||||
render(UploadModelFooter, {
|
||||
props: {
|
||||
currentStep: 3,
|
||||
isFetchingMetadata: false,
|
||||
isUploading: false,
|
||||
canFetchMetadata: true,
|
||||
canUploadModel: true,
|
||||
uploadStatus: 'success',
|
||||
...props
|
||||
},
|
||||
global: {
|
||||
plugins: [i18n],
|
||||
stubs: {
|
||||
VideoHelpDialog: true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('UploadModelFooter', () => {
|
||||
it('allows importing another model by default', () => {
|
||||
renderFooter()
|
||||
|
||||
expect(screen.getByRole('button', { name: 'Import Another' })).toBeEnabled()
|
||||
})
|
||||
|
||||
it('disables importing another model when the upload resolves a missing model', () => {
|
||||
renderFooter({ canImportAnother: false })
|
||||
|
||||
expect(
|
||||
screen.getByRole('button', { name: 'Import Another' })
|
||||
).toBeDisabled()
|
||||
})
|
||||
|
||||
it('shows recovery actions for upload errors', () => {
|
||||
renderFooter({ uploadStatus: 'error' })
|
||||
|
||||
expect(screen.getByRole('button', { name: 'Back' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'Close' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -73,7 +73,6 @@
|
||||
variant="muted-textonly"
|
||||
size="lg"
|
||||
data-attr="upload-model-step3-import-another-button"
|
||||
:disabled="!canImportAnother"
|
||||
@click="emit('importAnother')"
|
||||
>
|
||||
{{ $t('assetBrowser.importAnother') }}
|
||||
@@ -91,24 +90,6 @@
|
||||
}}
|
||||
</Button>
|
||||
</template>
|
||||
<template v-else-if="currentStep === 3 && uploadStatus === 'error'">
|
||||
<Button
|
||||
variant="muted-textonly"
|
||||
size="lg"
|
||||
data-attr="upload-model-step3-back-button"
|
||||
@click="emit('back')"
|
||||
>
|
||||
{{ $t('g.back') }}
|
||||
</Button>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="lg"
|
||||
data-attr="upload-model-step3-close-button"
|
||||
@click="emit('close')"
|
||||
>
|
||||
{{ $t('g.close') }}
|
||||
</Button>
|
||||
</template>
|
||||
<VideoHelpDialog
|
||||
v-model="showCivitaiHelp"
|
||||
video-url="https://media.comfy.org/compressed_768/civitai_howto.webm"
|
||||
@@ -132,14 +113,13 @@ import VideoHelpDialog from '@/platform/assets/components/VideoHelpDialog.vue'
|
||||
const showCivitaiHelp = ref(false)
|
||||
const showHuggingFaceHelp = ref(false)
|
||||
|
||||
const { canImportAnother = true } = defineProps<{
|
||||
defineProps<{
|
||||
currentStep: number
|
||||
isFetchingMetadata: boolean
|
||||
isUploading: boolean
|
||||
canFetchMetadata: boolean
|
||||
canUploadModel: boolean
|
||||
uploadStatus?: 'processing' | 'success' | 'error'
|
||||
canImportAnother?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import UploadModelProgress from './UploadModelProgress.vue'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false,
|
||||
escapeParameter: true
|
||||
})
|
||||
|
||||
describe('UploadModelProgress', () => {
|
||||
it('renders missing-model type mismatch labels without escaped entities', () => {
|
||||
render(UploadModelProgress, {
|
||||
props: {
|
||||
result: 'error',
|
||||
typeMismatch: {
|
||||
importedModelType: 'loras',
|
||||
importedModelTypeLabel: 'LoRA/Custom',
|
||||
requiredModelType: 'Ultralytics/bbox',
|
||||
requiredModelTypeLabel: 'Ultralytics/bbox'
|
||||
}
|
||||
},
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.getByText('This model cannot resolve the missing model.')
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByText('LoRA/Custom')).toBeInTheDocument()
|
||||
expect(screen.getAllByText('Ultralytics/bbox').length).toBeGreaterThan(0)
|
||||
expect(
|
||||
screen.getByText((_content, element) => {
|
||||
return (
|
||||
element?.textContent ===
|
||||
'Try importing a different Ultralytics/bbox model that this node can use.'
|
||||
)
|
||||
})
|
||||
).toBeInTheDocument()
|
||||
expect(screen.queryByText(///)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('uses fallback copy when the imported model type label is unknown', () => {
|
||||
render(UploadModelProgress, {
|
||||
props: {
|
||||
result: 'error',
|
||||
typeMismatch: {
|
||||
requiredModelType: 'checkpoints',
|
||||
requiredModelTypeLabel: 'Checkpoint'
|
||||
}
|
||||
},
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
})
|
||||
|
||||
expect(screen.getByText('another model type')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.getByText((_content, element) => {
|
||||
return (
|
||||
element?.textContent ===
|
||||
'This file is already imported as another model type.'
|
||||
)
|
||||
})
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,12 +1,5 @@
|
||||
<template>
|
||||
<div
|
||||
:class="
|
||||
cn(
|
||||
'flex flex-1 flex-col gap-6 text-sm text-muted-foreground',
|
||||
isTypeMismatchError && 'min-h-full justify-center'
|
||||
)
|
||||
"
|
||||
>
|
||||
<div class="flex flex-1 flex-col gap-6 text-sm text-muted-foreground">
|
||||
<!-- Processing State (202 async download in progress) -->
|
||||
<div v-if="result === 'processing'" class="flex flex-col gap-2">
|
||||
<p class="m-0 font-bold">
|
||||
@@ -74,51 +67,8 @@
|
||||
v-else-if="result === 'error'"
|
||||
class="flex flex-1 flex-col items-center justify-center gap-6"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="text-error"
|
||||
:class="
|
||||
typeMismatch
|
||||
? 'icon-[lucide--circle-alert] size-12'
|
||||
: 'icon-[lucide--x-circle] size-16'
|
||||
"
|
||||
/>
|
||||
<div
|
||||
v-if="typeMismatch"
|
||||
class="flex max-w-2xl flex-col gap-3 text-center"
|
||||
>
|
||||
<p class="m-0 text-sm font-bold">
|
||||
{{ $t('assetBrowser.missingModelImportTypeMismatchTitle') }}
|
||||
</p>
|
||||
<i18n-t
|
||||
keypath="assetBrowser.missingModelImportTypeMismatchAlreadyImported"
|
||||
tag="p"
|
||||
class="m-0 text-sm text-muted"
|
||||
>
|
||||
<template #actual>
|
||||
<span>{{ actualModelTypeLabel }}</span>
|
||||
</template>
|
||||
</i18n-t>
|
||||
<i18n-t
|
||||
keypath="assetBrowser.missingModelImportTypeMismatchRequired"
|
||||
tag="p"
|
||||
class="m-0 text-sm text-muted"
|
||||
>
|
||||
<template #required>
|
||||
<span>{{ typeMismatch.requiredModelTypeLabel }}</span>
|
||||
</template>
|
||||
</i18n-t>
|
||||
<i18n-t
|
||||
keypath="assetBrowser.missingModelImportTypeMismatchNextAction"
|
||||
tag="p"
|
||||
class="m-0 text-sm text-base-foreground"
|
||||
>
|
||||
<template #required>
|
||||
<span>{{ typeMismatch.requiredModelTypeLabel }}</span>
|
||||
</template>
|
||||
</i18n-t>
|
||||
</div>
|
||||
<div v-else class="text-center">
|
||||
<i class="icon-[lucide--x-circle] text-6xl text-error" />
|
||||
<div class="text-center">
|
||||
<p class="m-0 text-sm font-bold">
|
||||
{{ $t('assetBrowser.uploadFailed') }}
|
||||
</p>
|
||||
@@ -131,26 +81,13 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
import type { AssetMetadata } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { UploadModelTypeMismatch } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
|
||||
const { typeMismatch } = defineProps<{
|
||||
defineProps<{
|
||||
result: 'processing' | 'success' | 'error'
|
||||
error?: string
|
||||
metadata?: AssetMetadata
|
||||
modelType?: string
|
||||
previewImage?: string
|
||||
typeMismatch?: UploadModelTypeMismatch | null
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const isTypeMismatchError = computed(() => typeMismatch != null)
|
||||
const actualModelTypeLabel = computed(
|
||||
() =>
|
||||
typeMismatch?.importedModelTypeLabel ??
|
||||
t('assetBrowser.missingModelImportUnknownType')
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -3,28 +3,17 @@ import { computed } from 'vue'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import UploadModelDialog from '@/platform/assets/components/UploadModelDialog.vue'
|
||||
import UploadModelDialogHeader from '@/platform/assets/components/UploadModelDialogHeader.vue'
|
||||
import type {
|
||||
UploadModelDialogContext,
|
||||
UploadModelSuccess
|
||||
} from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import UploadModelUpgradeModal from '@/platform/assets/components/UploadModelUpgradeModal.vue'
|
||||
import UploadModelUpgradeModalHeader from '@/platform/assets/components/UploadModelUpgradeModalHeader.vue'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
type UploadModelContextResolver = () => UploadModelDialogContext | undefined
|
||||
|
||||
export function useModelUpload(
|
||||
onUploadSuccess?: (result: UploadModelSuccess) => Promise<unknown> | void,
|
||||
uploadContext?: UploadModelDialogContext | UploadModelContextResolver
|
||||
onUploadSuccess?: () => Promise<unknown> | void
|
||||
) {
|
||||
const dialogStore = useDialogStore()
|
||||
const { flags } = useFeatureFlags()
|
||||
const isUploadButtonEnabled = computed(() => flags.modelUploadButtonEnabled)
|
||||
|
||||
function resolveUploadContext() {
|
||||
return typeof uploadContext === 'function' ? uploadContext() : uploadContext
|
||||
}
|
||||
|
||||
function showUploadDialog() {
|
||||
if (!flags.privateModelsEnabled) {
|
||||
dialogStore.showDialog({
|
||||
@@ -44,9 +33,8 @@ export function useModelUpload(
|
||||
headerComponent: UploadModelDialogHeader,
|
||||
component: UploadModelDialog,
|
||||
props: {
|
||||
uploadContext: resolveUploadContext(),
|
||||
onUploadSuccess: async (result: UploadModelSuccess) => {
|
||||
await onUploadSuccess?.(result)
|
||||
onUploadSuccess: async () => {
|
||||
await onUploadSuccess?.()
|
||||
}
|
||||
},
|
||||
dialogComponentProps: {
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createApp, nextTick, ref } from 'vue'
|
||||
import type { App } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, ref } from 'vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
import type { AsyncUploadResponse } from '@/platform/assets/schemas/assetSchema'
|
||||
|
||||
import { useUploadModelWizard } from './useUploadModelWizard'
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetMetadata: vi.fn(),
|
||||
uploadAssetAsync: vi.fn(),
|
||||
uploadAssetPreviewImage: vi.fn()
|
||||
}
|
||||
@@ -49,52 +45,18 @@ vi.mock('@/i18n', () => ({
|
||||
d: (date: Date) => date.toISOString()
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({ t: (key: string) => key })
|
||||
}))
|
||||
|
||||
describe('useUploadModelWizard', () => {
|
||||
const modelTypes = ref([{ name: 'Checkpoint', value: 'checkpoints' }])
|
||||
const mountedApps: App<Element>[] = []
|
||||
|
||||
function setupWithI18n<T>(factory: () => T): T {
|
||||
let result: T | undefined
|
||||
const host = document.createElement('div')
|
||||
const app = createApp({
|
||||
setup() {
|
||||
result = factory()
|
||||
return () => null
|
||||
}
|
||||
})
|
||||
app.use(
|
||||
createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages }
|
||||
})
|
||||
)
|
||||
app.mount(host)
|
||||
mountedApps.push(app)
|
||||
|
||||
if (result === undefined) {
|
||||
throw new Error('Composable setup did not run')
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
function setupUploadModelWizard(
|
||||
...args: Parameters<typeof useUploadModelWizard>
|
||||
): ReturnType<typeof useUploadModelWizard> {
|
||||
return setupWithI18n(() => useUploadModelWizard(...args))
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
for (const app of mountedApps.splice(0)) {
|
||||
app.unmount()
|
||||
}
|
||||
})
|
||||
|
||||
it('updates uploadStatus to success when async download completes', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
@@ -109,18 +71,11 @@ describe('useUploadModelWizard', () => {
|
||||
}
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue(asyncResponse)
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
const wizard = useUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toEqual({
|
||||
filename: 'model',
|
||||
modelType: 'checkpoints',
|
||||
taskId: 'task-123',
|
||||
status: 'processing'
|
||||
})
|
||||
await wizard.uploadModel()
|
||||
|
||||
expect(wizard.uploadStatus.value).toBe('processing')
|
||||
|
||||
@@ -163,7 +118,7 @@ describe('useUploadModelWizard', () => {
|
||||
}
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue(asyncResponse)
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
const wizard = useUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/99999'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
@@ -214,7 +169,7 @@ describe('useUploadModelWizard', () => {
|
||||
}
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue(asyncResponse)
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
const wizard = useUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.red/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
@@ -223,124 +178,4 @@ describe('useUploadModelWizard', () => {
|
||||
expect(assetService.uploadAssetAsync).toHaveBeenCalled()
|
||||
expect(wizard.uploadStatus.value).toBe('processing')
|
||||
})
|
||||
|
||||
it('keeps a required model type when metadata suggests another type', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.getAssetMetadata).mockResolvedValue({
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/12345',
|
||||
filename: 'lora.safetensors',
|
||||
tags: ['loras']
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(
|
||||
ref([
|
||||
{ name: 'Checkpoint', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
]),
|
||||
{ requiredModelType: 'checkpoints' }
|
||||
)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
await wizard.fetchMetadata()
|
||||
|
||||
expect(wizard.selectedModelType.value).toBe('checkpoints')
|
||||
})
|
||||
|
||||
it('uploads with the required model type even if selection changes', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes, {
|
||||
requiredModelType: 'checkpoints'
|
||||
})
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'loras'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(assetService.uploadAssetAsync).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: expect.objectContaining({
|
||||
model_type: 'checkpoints'
|
||||
})
|
||||
})
|
||||
)
|
||||
expect(result?.modelType).toBe('checkpoints')
|
||||
})
|
||||
|
||||
it('blocks a missing-model import when an existing asset has the wrong model type', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-lora',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'loras']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(
|
||||
ref([
|
||||
{ name: 'Checkpoint', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
]),
|
||||
{ requiredModelType: 'checkpoints' }
|
||||
)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(wizard.uploadStatus.value).toBe('error')
|
||||
expect(wizard.uploadTypeMismatch.value).toEqual({
|
||||
importedModelType: 'loras',
|
||||
importedModelTypeLabel: 'LoRA',
|
||||
requiredModelType: 'checkpoints',
|
||||
requiredModelTypeLabel: 'Checkpoint'
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps generic sync imports successful when an existing asset has another model type', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-lora',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'loras']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(
|
||||
ref([
|
||||
{ name: 'Checkpoint', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
])
|
||||
)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toEqual({
|
||||
filename: 'model',
|
||||
modelType: 'checkpoints',
|
||||
status: 'success'
|
||||
})
|
||||
expect(wizard.uploadStatus.value).toBe('success')
|
||||
expect(wizard.uploadTypeMismatch.value).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,10 +5,7 @@ import { useI18n } from 'vue-i18n'
|
||||
import { st } from '@/i18n'
|
||||
import { civitaiImportSource } from '@/platform/assets/importSources/civitaiImportSource'
|
||||
import { huggingfaceImportSource } from '@/platform/assets/importSources/huggingfaceImportSource'
|
||||
import type {
|
||||
AssetItem,
|
||||
AssetMetadata
|
||||
} from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetMetadata } from '@/platform/assets/schemas/assetSchema'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import type { ImportSource } from '@/platform/assets/types/importSource'
|
||||
import { validateSourceUrl } from '@/platform/assets/utils/importSourceUtil'
|
||||
@@ -29,54 +26,16 @@ interface ModelTypeOption {
|
||||
value: string
|
||||
}
|
||||
|
||||
const MODEL_ROOT_TAG = 'models'
|
||||
|
||||
export interface UploadModelSuccess {
|
||||
filename: string
|
||||
modelType?: string
|
||||
taskId?: string
|
||||
status: 'processing' | 'success'
|
||||
}
|
||||
|
||||
export interface UploadModelTypeMismatch {
|
||||
importedModelType?: string
|
||||
importedModelTypeLabel?: string
|
||||
requiredModelType: string
|
||||
requiredModelTypeLabel: string
|
||||
}
|
||||
|
||||
interface MissingModelUploadContext {
|
||||
kind: 'missing-model-resolution'
|
||||
missingModelName: string
|
||||
requiredModelType: string
|
||||
replacementTargets: Array<{
|
||||
nodeId: string
|
||||
nodeLabel: string
|
||||
widgetName: string
|
||||
}>
|
||||
}
|
||||
|
||||
export type UploadModelDialogContext = MissingModelUploadContext
|
||||
|
||||
interface UploadModelWizardOptions {
|
||||
requiredModelType?: string
|
||||
}
|
||||
|
||||
export function useUploadModelWizard(
|
||||
modelTypes: Ref<ModelTypeOption[]>,
|
||||
options: UploadModelWizardOptions = {}
|
||||
) {
|
||||
export function useUploadModelWizard(modelTypes: Ref<ModelTypeOption[]>) {
|
||||
const { t } = useI18n()
|
||||
const assetsStore = useAssetsStore()
|
||||
const assetDownloadStore = useAssetDownloadStore()
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const requiredModelType = options.requiredModelType
|
||||
const currentStep = ref(1)
|
||||
const isFetchingMetadata = ref(false)
|
||||
const isUploading = ref(false)
|
||||
const uploadStatus = ref<'processing' | 'success' | 'error'>()
|
||||
const uploadError = ref('')
|
||||
const uploadTypeMismatch = ref<UploadModelTypeMismatch | null>(null)
|
||||
let stopAsyncWatch: (() => void) | undefined
|
||||
|
||||
const wizardData = ref<WizardData>({
|
||||
@@ -85,10 +44,7 @@ export function useUploadModelWizard(
|
||||
tags: []
|
||||
})
|
||||
|
||||
const selectedModelType = ref<string | undefined>(requiredModelType)
|
||||
const resolvedModelType = computed(
|
||||
() => requiredModelType ?? selectedModelType.value
|
||||
)
|
||||
const selectedModelType = ref<string>()
|
||||
|
||||
const importSources: ImportSource[] = [
|
||||
civitaiImportSource,
|
||||
@@ -109,29 +65,16 @@ export function useUploadModelWizard(
|
||||
() => wizardData.value.url,
|
||||
() => {
|
||||
uploadError.value = ''
|
||||
uploadTypeMismatch.value = null
|
||||
}
|
||||
)
|
||||
|
||||
if (requiredModelType) {
|
||||
watch(
|
||||
selectedModelType,
|
||||
(value) => {
|
||||
if (value !== requiredModelType) {
|
||||
selectedModelType.value = requiredModelType
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
}
|
||||
|
||||
// Validation - only enable Continue when URL matches a supported source
|
||||
const canFetchMetadata = computed(() => {
|
||||
return detectedSource.value !== null
|
||||
})
|
||||
|
||||
const canUploadModel = computed(() => {
|
||||
return !!resolvedModelType.value
|
||||
return !!selectedModelType.value
|
||||
})
|
||||
|
||||
async function fetchMetadata() {
|
||||
@@ -185,9 +128,7 @@ export function useUploadModelWizard(
|
||||
wizardData.value.previewImage = metadata.preview_image
|
||||
|
||||
// Pre-fill model type from metadata tags if available
|
||||
if (requiredModelType) {
|
||||
selectedModelType.value = requiredModelType
|
||||
} else if (metadata.tags && metadata.tags.length > 0) {
|
||||
if (metadata.tags && metadata.tags.length > 0) {
|
||||
wizardData.value.tags = metadata.tags
|
||||
// Try to detect model type from tags
|
||||
const typeTag = metadata.tags.find((tag) =>
|
||||
@@ -242,10 +183,10 @@ export function useUploadModelWizard(
|
||||
}
|
||||
|
||||
async function refreshModelCaches() {
|
||||
if (!resolvedModelType.value) return
|
||||
if (!selectedModelType.value) return
|
||||
|
||||
const providers = modelToNodeStore.getAllNodeProviders(
|
||||
resolvedModelType.value
|
||||
selectedModelType.value
|
||||
)
|
||||
const results = await Promise.allSettled(
|
||||
providers.map((provider) =>
|
||||
@@ -262,61 +203,24 @@ export function useUploadModelWizard(
|
||||
})
|
||||
}
|
||||
|
||||
function getModelTypeLabel(modelType: string): string {
|
||||
return (
|
||||
modelTypes.value.find((type) => type.value === modelType)?.name ??
|
||||
modelType
|
||||
)
|
||||
}
|
||||
|
||||
function getImportedModelType(asset: AssetItem): string | undefined {
|
||||
const knownType = asset.tags.find(
|
||||
(tag) =>
|
||||
tag !== MODEL_ROOT_TAG &&
|
||||
modelTypes.value.some((type) => type.value === tag)
|
||||
)
|
||||
return knownType ?? asset.tags.find((tag) => tag !== MODEL_ROOT_TAG)
|
||||
}
|
||||
|
||||
function blockMismatchedImportedModel(
|
||||
asset: AssetItem,
|
||||
requiredType: string
|
||||
): boolean {
|
||||
if (asset.tags.includes(requiredType)) return false
|
||||
|
||||
const importedType = getImportedModelType(asset)
|
||||
uploadStatus.value = 'error'
|
||||
uploadError.value = ''
|
||||
uploadTypeMismatch.value = {
|
||||
importedModelType: importedType,
|
||||
importedModelTypeLabel: importedType
|
||||
? getModelTypeLabel(importedType)
|
||||
: undefined,
|
||||
requiredModelType: requiredType,
|
||||
requiredModelTypeLabel: getModelTypeLabel(requiredType)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
async function uploadModel(): Promise<UploadModelSuccess | null> {
|
||||
if (isUploading.value) return null
|
||||
async function uploadModel(): Promise<boolean> {
|
||||
if (isUploading.value) return false
|
||||
if (!canUploadModel.value) {
|
||||
return null
|
||||
return false
|
||||
}
|
||||
|
||||
const source = detectedSource.value
|
||||
if (!source) {
|
||||
uploadError.value = t('assetBrowser.noValidSourceDetected')
|
||||
return null
|
||||
return false
|
||||
}
|
||||
|
||||
isUploading.value = true
|
||||
uploadTypeMismatch.value = null
|
||||
let uploadSuccess: UploadModelSuccess | null = null
|
||||
|
||||
try {
|
||||
const modelType = resolvedModelType.value
|
||||
const tags = modelType ? ['models', modelType] : ['models']
|
||||
const tags = selectedModelType.value
|
||||
? ['models', selectedModelType.value]
|
||||
: ['models']
|
||||
const filename =
|
||||
wizardData.value.metadata?.filename ||
|
||||
wizardData.value.metadata?.name ||
|
||||
@@ -326,7 +230,7 @@ export function useUploadModelWizard(
|
||||
const userMetadata = {
|
||||
source: source.type,
|
||||
source_url: wizardData.value.url,
|
||||
model_type: modelType
|
||||
model_type: selectedModelType.value
|
||||
}
|
||||
|
||||
const result = await assetService.uploadAssetAsync({
|
||||
@@ -337,20 +241,14 @@ export function useUploadModelWizard(
|
||||
})
|
||||
|
||||
if (result.type === 'async' && result.task.status !== 'completed') {
|
||||
if (modelType) {
|
||||
if (selectedModelType.value) {
|
||||
assetDownloadStore.trackDownload(
|
||||
result.task.task_id,
|
||||
modelType,
|
||||
selectedModelType.value,
|
||||
filename
|
||||
)
|
||||
}
|
||||
uploadStatus.value = 'processing'
|
||||
uploadSuccess = {
|
||||
filename,
|
||||
modelType,
|
||||
taskId: result.task.task_id,
|
||||
status: 'processing'
|
||||
}
|
||||
|
||||
stopAsyncWatch?.()
|
||||
let resolved = false
|
||||
@@ -390,23 +288,8 @@ export function useUploadModelWizard(
|
||||
stopAsyncWatch = stop
|
||||
}
|
||||
} else {
|
||||
if (
|
||||
requiredModelType &&
|
||||
result.type === 'sync' &&
|
||||
modelType &&
|
||||
blockMismatchedImportedModel(result.asset, modelType)
|
||||
) {
|
||||
currentStep.value = 3
|
||||
return null
|
||||
}
|
||||
|
||||
uploadStatus.value = 'success'
|
||||
await refreshModelCaches()
|
||||
uploadSuccess = {
|
||||
filename,
|
||||
modelType,
|
||||
status: 'success'
|
||||
}
|
||||
}
|
||||
currentStep.value = 3
|
||||
} catch (error) {
|
||||
@@ -418,7 +301,7 @@ export function useUploadModelWizard(
|
||||
} finally {
|
||||
isUploading.value = false
|
||||
}
|
||||
return uploadSuccess
|
||||
return uploadStatus.value !== 'error'
|
||||
}
|
||||
|
||||
function goToPreviousStep() {
|
||||
@@ -435,13 +318,12 @@ export function useUploadModelWizard(
|
||||
isUploading.value = false
|
||||
uploadStatus.value = undefined
|
||||
uploadError.value = ''
|
||||
uploadTypeMismatch.value = null
|
||||
wizardData.value = {
|
||||
url: '',
|
||||
name: '',
|
||||
tags: []
|
||||
}
|
||||
selectedModelType.value = requiredModelType
|
||||
selectedModelType.value = undefined
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -451,7 +333,6 @@ export function useUploadModelWizard(
|
||||
isUploading,
|
||||
uploadStatus,
|
||||
uploadError,
|
||||
uploadTypeMismatch,
|
||||
wizardData,
|
||||
selectedModelType,
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ const zAsset = z.object({
|
||||
})
|
||||
|
||||
const zAssetResponse = zListAssetsResponse
|
||||
.pick({ total: true, has_more: true, next_cursor: true })
|
||||
.pick({ total: true, has_more: true })
|
||||
.extend({
|
||||
assets: z.array(zAsset)
|
||||
})
|
||||
|
||||
@@ -53,7 +53,6 @@ const fetchApiMock = vi.mocked(api.fetchApi)
|
||||
type AssetListResponseOptions = {
|
||||
hasMore?: AssetResponse['has_more']
|
||||
total?: AssetResponse['total']
|
||||
nextCursor?: AssetResponse['next_cursor']
|
||||
}
|
||||
|
||||
function buildResponse(
|
||||
@@ -69,18 +68,9 @@ function buildResponse(
|
||||
|
||||
function buildAssetListResponse(
|
||||
assets: AssetItem[],
|
||||
{
|
||||
hasMore = false,
|
||||
total = assets.length,
|
||||
nextCursor
|
||||
}: AssetListResponseOptions = {}
|
||||
{ hasMore = false, total = assets.length }: AssetListResponseOptions = {}
|
||||
): Response {
|
||||
return buildResponse({
|
||||
assets,
|
||||
total,
|
||||
has_more: hasMore,
|
||||
...(nextCursor === undefined ? {} : { next_cursor: nextCursor })
|
||||
})
|
||||
return buildResponse({ assets, total, has_more: hasMore })
|
||||
}
|
||||
|
||||
function validAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
@@ -522,7 +512,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('walks pages by keyset cursor with include_public=true', async () => {
|
||||
it('paginates tagged asset requests with include_public=true', async () => {
|
||||
fetchApiMock
|
||||
.mockResolvedValueOnce(
|
||||
buildAssetListResponse(
|
||||
@@ -530,7 +520,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
validAsset({ id: 'a', tags: ['input'] }),
|
||||
validAsset({ id: 'b', tags: ['input'] })
|
||||
],
|
||||
{ hasMore: true, nextCursor: 'cursor-page-2' }
|
||||
{ hasMore: true }
|
||||
)
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
@@ -548,8 +538,6 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
expect(firstParams.get('include_public')).toBe('true')
|
||||
expect(firstParams.get('exclude_tags')).toBe(MISSING_TAG)
|
||||
expect(firstParams.get('limit')).toBe('2')
|
||||
// First page carries neither a cursor nor an offset.
|
||||
expect(firstParams.has('after')).toBe(false)
|
||||
expect(firstParams.has('offset')).toBe(false)
|
||||
|
||||
const secondUrl = fetchApiMock.mock.calls[1]?.[0] as string
|
||||
@@ -557,9 +545,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
expect(secondParams.get('include_public')).toBe('true')
|
||||
expect(secondParams.get('exclude_tags')).toBe(MISSING_TAG)
|
||||
expect(secondParams.get('limit')).toBe('2')
|
||||
// Subsequent pages resume from the prior response's next_cursor, never offset.
|
||||
expect(secondParams.get('after')).toBe('cursor-page-2')
|
||||
expect(secondParams.has('offset')).toBe(false)
|
||||
expect(secondParams.get('offset')).toBe('2')
|
||||
})
|
||||
|
||||
it('honors has_more when walking tagged asset pages', async () => {
|
||||
@@ -570,7 +556,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
validAsset({ id: 'first', tags: ['input'] }),
|
||||
validAsset({ id: 'second', tags: ['input'] })
|
||||
],
|
||||
{ hasMore: true, nextCursor: 'cursor-next' }
|
||||
{ hasMore: true }
|
||||
)
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
@@ -591,45 +577,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
throw new Error('Expected a second asset request URL')
|
||||
}
|
||||
const secondParams = new URL(secondUrl, 'http://localhost').searchParams
|
||||
expect(secondParams.get('after')).toBe('cursor-next')
|
||||
})
|
||||
|
||||
it('stops walking when next_cursor is absent even if has_more is true', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildAssetListResponse([validAsset({ id: 'only', tags: ['input'] })], {
|
||||
hasMore: true
|
||||
})
|
||||
)
|
||||
|
||||
const assets = await assetService.getAllAssetsByTag('input', true, {
|
||||
limit: 2
|
||||
})
|
||||
|
||||
expect(assets.map((a) => a.id)).toEqual(['only'])
|
||||
expect(fetchApiMock).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('stops walking when the server returns a non-advancing cursor', async () => {
|
||||
fetchApiMock
|
||||
.mockResolvedValueOnce(
|
||||
buildAssetListResponse([validAsset({ id: 'a', tags: ['input'] })], {
|
||||
hasMore: true,
|
||||
nextCursor: 'stuck'
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
buildAssetListResponse([validAsset({ id: 'b', tags: ['input'] })], {
|
||||
hasMore: true,
|
||||
nextCursor: 'stuck'
|
||||
})
|
||||
)
|
||||
|
||||
const assets = await assetService.getAllAssetsByTag('input', true, {
|
||||
limit: 1
|
||||
})
|
||||
|
||||
expect(assets.map((a) => a.id)).toEqual(['a', 'b'])
|
||||
expect(fetchApiMock).toHaveBeenCalledTimes(2)
|
||||
expect(secondParams.get('offset')).toBe('2')
|
||||
})
|
||||
|
||||
it.for([
|
||||
@@ -688,7 +636,7 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
validAsset({ id: 'a', tags: ['input'] }),
|
||||
validAsset({ id: 'b', tags: ['input'] })
|
||||
],
|
||||
{ hasMore: true, nextCursor: 'cursor-page-2' }
|
||||
{ hasMore: true }
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
@@ -31,11 +31,6 @@ export interface PaginationOptions {
|
||||
}
|
||||
|
||||
interface AssetPaginationOptions extends PaginationOptions {
|
||||
/**
|
||||
* Opaque keyset cursor from a prior response's `next_cursor`. When set, the
|
||||
* server resumes after that cursor and `offset` is ignored.
|
||||
*/
|
||||
after?: string
|
||||
signal?: AbortSignal
|
||||
}
|
||||
|
||||
@@ -43,7 +38,6 @@ interface AssetRequestOptions extends PaginationOptions {
|
||||
includeTags: string[]
|
||||
excludeTags?: string[]
|
||||
includePublic?: boolean
|
||||
after?: string
|
||||
signal?: AbortSignal
|
||||
}
|
||||
|
||||
@@ -292,7 +286,6 @@ function createAssetService() {
|
||||
excludeTags = DEFAULT_EXCLUDED_ASSET_TAGS,
|
||||
limit = DEFAULT_LIMIT,
|
||||
offset,
|
||||
after,
|
||||
includePublic,
|
||||
signal
|
||||
} = options
|
||||
@@ -306,11 +299,7 @@ function createAssetService() {
|
||||
if (normalizedExcludeTags.length > 0) {
|
||||
queryParams.set('exclude_tags', normalizedExcludeTags.join(','))
|
||||
}
|
||||
// `after` (keyset cursor) takes precedence over `offset`; the server ignores
|
||||
// `offset` when a cursor is supplied, so we avoid sending a redundant param.
|
||||
if (after) {
|
||||
queryParams.set('after', after)
|
||||
} else if (offset !== undefined && offset > 0) {
|
||||
if (offset !== undefined && offset > 0) {
|
||||
queryParams.set('offset', offset.toString())
|
||||
}
|
||||
if (includePublic !== undefined) {
|
||||
@@ -492,17 +481,11 @@ function createAssetService() {
|
||||
async function getAssetsByTag(
|
||||
tag: string,
|
||||
includePublic: boolean = true,
|
||||
{
|
||||
limit = DEFAULT_LIMIT,
|
||||
offset = 0,
|
||||
after,
|
||||
signal
|
||||
}: AssetPaginationOptions = {}
|
||||
{ limit = DEFAULT_LIMIT, offset = 0, signal }: AssetPaginationOptions = {}
|
||||
): Promise<AssetItem[]> {
|
||||
const data = await getAssetsPageByTag(tag, includePublic, {
|
||||
limit,
|
||||
offset,
|
||||
after,
|
||||
signal
|
||||
})
|
||||
|
||||
@@ -515,27 +498,17 @@ function createAssetService() {
|
||||
async function getAssetsPageByTag(
|
||||
tag: string,
|
||||
includePublic: boolean = true,
|
||||
{
|
||||
limit = DEFAULT_LIMIT,
|
||||
offset = 0,
|
||||
after,
|
||||
signal
|
||||
}: AssetPaginationOptions = {}
|
||||
{ limit = DEFAULT_LIMIT, offset = 0, signal }: AssetPaginationOptions = {}
|
||||
): Promise<AssetResponse> {
|
||||
return await handleAssetRequest(
|
||||
{ includeTags: [tag], limit, offset, after, includePublic, signal },
|
||||
{ includeTags: [tag], limit, offset, includePublic, signal },
|
||||
`assets for tag ${tag}`
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets every asset for a tag by walking paginated asset API responses.
|
||||
*
|
||||
* Uses keyset (cursor) pagination: each page is fetched with the prior
|
||||
* response's `next_cursor`, which is stable under concurrent inserts/deletes
|
||||
* and avoids the duplicate/skip drift that offset paging exhibits when the
|
||||
* underlying set changes mid-walk. Falls back to terminating on `has_more`
|
||||
* when the server omits `next_cursor`.
|
||||
* Pagination follows the required server-provided `has_more` flag.
|
||||
*
|
||||
* @param tag - The tag to filter by (e.g., 'models', 'input')
|
||||
* @param includePublic - Whether to include public assets (default: true)
|
||||
@@ -547,21 +520,18 @@ function createAssetService() {
|
||||
async function getAllAssetsByTag(
|
||||
tag: string,
|
||||
includePublic: boolean = true,
|
||||
{
|
||||
limit = DEFAULT_LIMIT,
|
||||
signal
|
||||
}: Pick<AssetPaginationOptions, 'limit' | 'signal'> = {}
|
||||
{ limit = DEFAULT_LIMIT, signal }: AssetPaginationOptions = {}
|
||||
): Promise<AssetItem[]> {
|
||||
const assets: AssetItem[] = []
|
||||
const pageSize = limit > 0 ? limit : DEFAULT_LIMIT
|
||||
let after: string | undefined
|
||||
let offset = 0
|
||||
|
||||
while (true) {
|
||||
if (signal?.aborted) throw createAbortError()
|
||||
|
||||
const data = await getAssetsPageByTag(tag, includePublic, {
|
||||
limit: pageSize,
|
||||
after,
|
||||
offset,
|
||||
signal
|
||||
})
|
||||
const batch = data.assets
|
||||
@@ -571,12 +541,11 @@ function createAssetService() {
|
||||
|
||||
assets.push(...batch)
|
||||
|
||||
// A server that returns a non-advancing cursor would loop forever.
|
||||
if (!data.has_more || !data.next_cursor || data.next_cursor === after) {
|
||||
if (!data.has_more) {
|
||||
return assets
|
||||
}
|
||||
|
||||
after = data.next_cursor
|
||||
offset += batch.length
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,439 +0,0 @@
|
||||
// oxlint-disable no-misused-spread
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { markRaw } from 'vue'
|
||||
import type { Raw } from 'vue'
|
||||
|
||||
import type { LGraphNode, Subgraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type * as LitegraphModule from '@/lib/litegraph/src/litegraph'
|
||||
import type * as ModelToNodeStoreModule from '@/stores/modelToNodeStore'
|
||||
import type * as WorkflowStoreModule from '@/platform/workflow/management/stores/workflowStore'
|
||||
import type * as LitegraphServiceModule from '@/services/litegraphService'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
apiURL: vi.fn((path: string) => `http://localhost:8188${path}`)
|
||||
}
|
||||
}))
|
||||
vi.mock('@/stores/modelToNodeStore', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof ModelToNodeStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useModelToNodeStore: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock(
|
||||
'@/platform/workflow/management/stores/workflowStore',
|
||||
async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof WorkflowStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useWorkflowStore: vi.fn()
|
||||
}
|
||||
}
|
||||
)
|
||||
vi.mock('@/services/litegraphService', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphServiceModule>()
|
||||
return {
|
||||
...actual,
|
||||
useLitegraphService: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock('@/lib/litegraph/src/litegraph', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphModule>()
|
||||
return {
|
||||
...actual,
|
||||
LiteGraph: {
|
||||
...actual.LiteGraph,
|
||||
createNode: vi.fn()
|
||||
}
|
||||
}
|
||||
})
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
canvas: {
|
||||
graph: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
async function createMockNode(overrides?: {
|
||||
widgetName?: string
|
||||
widgetValue?: string
|
||||
hasWidgets?: boolean
|
||||
}): Promise<LGraphNode> {
|
||||
const {
|
||||
widgetName = 'ckpt_name',
|
||||
widgetValue = '',
|
||||
hasWidgets = true
|
||||
} = overrides || {}
|
||||
|
||||
const { LGraphNode: ActualLGraphNode } = await vi.importActual<
|
||||
typeof LitegraphModule
|
||||
>('@/lib/litegraph/src/litegraph')
|
||||
|
||||
if (!hasWidgets) {
|
||||
return Object.create(ActualLGraphNode.prototype)
|
||||
}
|
||||
|
||||
type Widget = NonNullable<LGraphNode['widgets']>[number]
|
||||
const widget: Pick<Widget, 'name' | 'value' | 'type' | 'options' | 'y'> = {
|
||||
name: widgetName,
|
||||
value: widgetValue,
|
||||
type: 'string',
|
||||
options: {},
|
||||
y: 0
|
||||
}
|
||||
|
||||
return Object.create(ActualLGraphNode.prototype, {
|
||||
widgets: { value: [widget], writable: true }
|
||||
})
|
||||
}
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Configures all mocked dependencies with sensible defaults.
|
||||
* Uses semantic parameters for clearer test intent.
|
||||
* For error paths or edge cases, pass null values or specific overrides.
|
||||
*/
|
||||
async function setupMocks(
|
||||
overrides: {
|
||||
nodeProvider?: ReturnType<typeof createMockNodeProvider> | null
|
||||
canvasCenter?: [number, number]
|
||||
activeSubgraph?: Raw<Subgraph>
|
||||
createdNode?: Awaited<ReturnType<typeof createMockNode>> | null
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
nodeProvider = createMockNodeProvider(),
|
||||
canvasCenter = [100, 200],
|
||||
activeSubgraph = undefined,
|
||||
createdNode = await createMockNode()
|
||||
} = overrides
|
||||
|
||||
vi.mocked(useModelToNodeStore).mockReturnValue({
|
||||
...useModelToNodeStore(),
|
||||
getNodeProvider: vi.fn().mockReturnValue(nodeProvider)
|
||||
})
|
||||
|
||||
vi.mocked(useLitegraphService).mockReturnValue({
|
||||
...useLitegraphService(),
|
||||
getCanvasCenter: vi.fn().mockReturnValue(canvasCenter)
|
||||
})
|
||||
|
||||
vi.mocked(useWorkflowStore).mockReturnValue({
|
||||
...useWorkflowStore(),
|
||||
activeSubgraph,
|
||||
isSubgraphActive: !!activeSubgraph
|
||||
})
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(createdNode)
|
||||
}
|
||||
describe('createModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
})
|
||||
describe('when creating nodes from valid assets', () => {
|
||||
it('should create the appropriate loader node for the asset category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(
|
||||
vi.mocked(useModelToNodeStore)().getNodeProvider
|
||||
).toHaveBeenCalledWith('checkpoints')
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [100, 200] }
|
||||
)
|
||||
}
|
||||
})
|
||||
it('should place node at canvas center by default', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({
|
||||
canvasCenter: [150, 250]
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [150, 250] }
|
||||
)
|
||||
})
|
||||
it('should place node at specified position when position is provided', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset, { position: [300, 400] })
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).not.toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [300, 400] }
|
||||
)
|
||||
})
|
||||
it('should populate the loader widget with the asset file path', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
it('should add node to root graph when no subgraph is active', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
it('should fallback to asset.metadata.filename when user_metadata.filename missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
})
|
||||
it('should fallback to asset.name when both filename sources missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: undefined
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe('test-model.safetensors')
|
||||
})
|
||||
it('should add node to active subgraph when present', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
const { Subgraph } = await vi.importActual<typeof LitegraphModule>(
|
||||
'@/lib/litegraph/src/litegraph'
|
||||
)
|
||||
const mockSubgraph = markRaw(
|
||||
Object.create(Subgraph.prototype, {
|
||||
add: { value: vi.fn() }
|
||||
})
|
||||
)
|
||||
await setupMocks({
|
||||
createdNode: mockNode,
|
||||
activeSubgraph: mockSubgraph
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should succeed when provider has empty key (auto-load nodes)', async () => {
|
||||
const asset = createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
const nodeProvider = createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
await setupMocks({ createdNode: mockNode, nodeProvider })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
})
|
||||
describe('when asset data is incomplete or invalid', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: { user_metadata: undefined, metadata: undefined, name: '' },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorPattern }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
}
|
||||
)
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorMessage }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toBe(errorMessage)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
describe('when system resources are unavailable', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it('should fail when no provider registered for category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({ nodeProvider: null })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
it('should fail when node creation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(null)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NODE_CREATION_FAILED')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
}
|
||||
})
|
||||
it('should fail when widget is missing from node', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ widgetName: 'wrong_widget' })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
expect(result.error.details?.widgetName).toBe('ckpt_name')
|
||||
}
|
||||
})
|
||||
it('should fail when node has no widgets array', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name not found')
|
||||
}
|
||||
})
|
||||
it('should not add node to graph when widget validation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
createModelNodeFromAsset(asset)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
describe('when graph is null', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(app).canvas.graph = null
|
||||
})
|
||||
it('should fail when no graph is available', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_GRAPH')
|
||||
expect(result.error.message).toBe('No active graph available')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,198 +0,0 @@
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LGraphNode, Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
interface ModelNodeCreateOptions {
|
||||
position?: Point
|
||||
}
|
||||
|
||||
type NodeCreationErrorCode =
|
||||
| 'INVALID_ASSET'
|
||||
| 'NO_PROVIDER'
|
||||
| 'NODE_CREATION_FAILED'
|
||||
| 'MISSING_WIDGET'
|
||||
| 'NO_GRAPH'
|
||||
|
||||
interface NodeCreationError {
|
||||
code: NodeCreationErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Creates a LiteGraph node from an asset item.
|
||||
*
|
||||
* **Boundary Function**: Bridges Vue reactive domain with LiteGraph canvas domain.
|
||||
*
|
||||
* @param asset - Asset item to create node from (Vue domain)
|
||||
* @param options - Optional position and configuration
|
||||
* @returns Result with LiteGraph node (Canvas domain) or error details
|
||||
*
|
||||
* @remarks
|
||||
* This function performs side effects on the canvas graph. Validation failures
|
||||
* return error results rather than throwing to allow graceful degradation in UI contexts.
|
||||
* Widget validation occurs before graph mutation to prevent orphaned nodes.
|
||||
*/
|
||||
export function createModelNodeFromAsset(
|
||||
asset: AssetItem,
|
||||
options?: ModelNodeCreateOptions
|
||||
): Result<LGraphNode, NodeCreationError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: asset.id,
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const provider = modelToNodeStore.getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const litegraphService = useLitegraphService()
|
||||
const pos = options?.position ?? litegraphService.getCanvasCenter()
|
||||
|
||||
const node = LiteGraph.createNode(
|
||||
provider.nodeDef.name,
|
||||
provider.nodeDef.display_name,
|
||||
{ pos }
|
||||
)
|
||||
|
||||
if (!node) {
|
||||
console.error(`Failed to create node for type: ${provider.nodeDef.name}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NODE_CREATION_FAILED',
|
||||
message: `Failed to create node for type: ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const workflowStore = useWorkflowStore()
|
||||
const targetGraph = workflowStore.isSubgraphActive
|
||||
? workflowStore.activeSubgraph
|
||||
: app.canvas.graph
|
||||
|
||||
if (!targetGraph) {
|
||||
console.error('No active graph available')
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_GRAPH',
|
||||
message: 'No active graph available',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set widget value if provider specifies a key (some nodes auto-load models without a widget)
|
||||
if (provider.key) {
|
||||
const widget = node.widgets?.find((w) => w.name === provider.key)
|
||||
if (!widget) {
|
||||
console.error(
|
||||
`Widget ${provider.key} not found on node ${provider.nodeDef.name}`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'MISSING_WIDGET',
|
||||
message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { widgetName: provider.key, nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
widget.value = filename
|
||||
}
|
||||
|
||||
// Add the node to the graph
|
||||
targetGraph.add(node)
|
||||
|
||||
return { success: true, value: node }
|
||||
}
|
||||
208
src/platform/assets/utils/resolveModelNodeFromAsset.test.ts
Normal file
@@ -0,0 +1,208 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
|
||||
const mockGetNodeProvider = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
|
||||
function mockProvider(
|
||||
provider: ReturnType<typeof createMockNodeProvider> | null
|
||||
) {
|
||||
mockGetNodeProvider.mockReturnValue(provider)
|
||||
}
|
||||
|
||||
describe('resolveModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
describe('valid assets', () => {
|
||||
it('resolves the provider for the asset category and the filename', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(result.value.provider.nodeDef.name).toBe(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to metadata.filename when user_metadata.filename missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to asset.name when both filename sources missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({ user_metadata: {}, metadata: undefined })
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe('test-model.safetensors')
|
||||
}
|
||||
})
|
||||
|
||||
it('resolves a provider with an empty key (auto-load nodes)', () => {
|
||||
mockProvider(
|
||||
createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
)
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.provider.key).toBe('')
|
||||
expect(result.value.filename).toBe('chatterbox_vc_model.pt')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('invalid assets', () => {
|
||||
it('fails when the asset does not match the schema', () => {
|
||||
const invalid = {
|
||||
id: 'asset-123',
|
||||
tags: ['models', 'checkpoints']
|
||||
} as unknown as AssetItem
|
||||
|
||||
const result = resolveModelNodeFromAsset(invalid)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe('Asset schema validation failed')
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
expect(result.error.details?.validationErrors).toBeTruthy()
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: {
|
||||
user_metadata: undefined,
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, errorPattern }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
message: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, message }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe(message)
|
||||
}
|
||||
})
|
||||
|
||||
it('fails when no provider is registered for the category', () => {
|
||||
mockProvider(null)
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
117
src/platform/assets/utils/resolveModelNodeFromAsset.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
type ResolveErrorCode = 'INVALID_ASSET' | 'NO_PROVIDER'
|
||||
|
||||
export interface ResolveModelNodeError {
|
||||
code: ResolveErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface ResolvedModelNode {
|
||||
provider: ModelNodeProvider
|
||||
filename: string
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Resolves an asset item to the node provider and filename needed to add a
|
||||
* model loader node. Validation failures return error results rather than
|
||||
* throwing, so callers can degrade gracefully in UI contexts.
|
||||
*/
|
||||
export function resolveModelNodeFromAsset(
|
||||
asset: AssetItem
|
||||
): Result<ResolvedModelNode, ResolveModelNodeError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: typeof asset?.id === 'string' ? asset.id : 'unknown',
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const provider = useModelToNodeStore().getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { success: true, value: { provider, filename } }
|
||||
}
|
||||
@@ -1,36 +1,24 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { render, screen, within } from '@testing-library/vue'
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { nextTick } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
import type {
|
||||
MissingModelGroup,
|
||||
MissingModelViewModel
|
||||
} from '@/platform/missingModel/types'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
|
||||
vi.mock('./MissingModelRow.vue', () => ({
|
||||
default: {
|
||||
name: 'MissingModelRow',
|
||||
template: `
|
||||
<div
|
||||
data-testid="model-row"
|
||||
class="model-row"
|
||||
:data-model-name="model.name"
|
||||
:data-is-asset-supported="isAssetSupported"
|
||||
:data-directory="directory"
|
||||
>
|
||||
<button
|
||||
class="locate-trigger"
|
||||
@click="$emit('locate-model', model?.representative?.nodeId)"
|
||||
>
|
||||
Locate
|
||||
</button>
|
||||
</div>
|
||||
`,
|
||||
props: ['model', 'directory', 'isAssetSupported'],
|
||||
template:
|
||||
'<div class="model-row" :data-show-node-id-badge="showNodeIdBadge" :data-is-asset-supported="isAssetSupported" :data-directory="directory"><button class="locate-trigger" @click="$emit(\'locate-model\', model?.representative?.nodeId)">Locate</button></div>',
|
||||
props: ['model', 'directory', 'showNodeIdBadge', 'isAssetSupported'],
|
||||
emits: ['locate-model']
|
||||
}
|
||||
}))
|
||||
@@ -47,7 +35,21 @@ import MissingModelCard from './MissingModelCard.vue'
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
messages: {
|
||||
en: {
|
||||
rightSidePanel: {
|
||||
missingModels: {
|
||||
importNotSupported: 'Import Not Supported',
|
||||
customNodeDownloadDisabled:
|
||||
'Cloud environment does not support model imports for custom nodes.',
|
||||
unknownCategory: 'Unknown Category',
|
||||
downloadAll: 'Download all',
|
||||
refresh: 'Refresh',
|
||||
refreshing: 'Refreshing missing models.'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
@@ -104,6 +106,7 @@ function makeGroup(
|
||||
function mountCard(
|
||||
props: Partial<{
|
||||
missingModelGroups: MissingModelGroup[]
|
||||
showNodeIdBadge: boolean
|
||||
}> = {},
|
||||
onLocateModel?: (nodeId: string) => void
|
||||
) {
|
||||
@@ -111,6 +114,7 @@ function mountCard(
|
||||
return render(MissingModelCard, {
|
||||
props: {
|
||||
missingModelGroups: [makeGroup()],
|
||||
showNodeIdBadge: false,
|
||||
...props,
|
||||
...(onLocateModel ? { onLocateModel } : {})
|
||||
},
|
||||
@@ -120,70 +124,62 @@ function mountCard(
|
||||
})
|
||||
}
|
||||
|
||||
function getRows() {
|
||||
return screen.queryAllByTestId('model-row')
|
||||
}
|
||||
|
||||
describe('MissingModelCard', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = true
|
||||
})
|
||||
|
||||
describe('Rendering & Props', () => {
|
||||
it('passes the model directory to rows', () => {
|
||||
mockIsCloud.value = false
|
||||
mountCard({
|
||||
it('renders directory name in category header', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [makeGroup({ directory: 'loras' })]
|
||||
})
|
||||
expect(getRows()[0].getAttribute('data-directory')).toBe('loras')
|
||||
expect(container.textContent).toContain('loras')
|
||||
})
|
||||
|
||||
it('renders translated unknown category when directory is null', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [makeGroup({ directory: null })]
|
||||
})
|
||||
expect(container.textContent).toContain('Unknown Category')
|
||||
})
|
||||
|
||||
it('renders model count in category header', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({ modelNames: ['a.safetensors', 'b.safetensors'] })
|
||||
]
|
||||
})
|
||||
expect(container.textContent).toContain('(2)')
|
||||
})
|
||||
|
||||
it('renders correct number of MissingModelRow components', () => {
|
||||
mountCard({
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({
|
||||
modelNames: ['a.safetensors', 'b.safetensors', 'c.safetensors']
|
||||
})
|
||||
]
|
||||
})
|
||||
expect(getRows()).toHaveLength(3)
|
||||
// eslint-disable-next-line testing-library/no-container, testing-library/no-node-access
|
||||
expect(container.querySelectorAll('.model-row')).toHaveLength(3)
|
||||
})
|
||||
|
||||
it('flattens multiple groups into rows', () => {
|
||||
mockIsCloud.value = false
|
||||
mountCard({
|
||||
it('renders multiple groups', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({ directory: 'checkpoints' }),
|
||||
makeGroup({ directory: 'loras' })
|
||||
]
|
||||
})
|
||||
expect(getRows()).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('sorts rows by model type order in cloud', () => {
|
||||
mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({ directory: null, modelNames: ['unknown.safetensors'] }),
|
||||
makeGroup({ directory: 'loras', modelNames: ['lora.safetensors'] }),
|
||||
makeGroup({
|
||||
directory: 'checkpoints',
|
||||
modelNames: ['checkpoint.safetensors']
|
||||
})
|
||||
]
|
||||
})
|
||||
|
||||
expect(
|
||||
getRows().map((row) => row.getAttribute('data-model-name'))
|
||||
).toEqual([
|
||||
'checkpoint.safetensors',
|
||||
'lora.safetensors',
|
||||
'unknown.safetensors'
|
||||
])
|
||||
expect(container.textContent).toContain('checkpoints')
|
||||
expect(container.textContent).toContain('loras')
|
||||
})
|
||||
|
||||
it('renders zero rows when missingModelGroups is empty', () => {
|
||||
mountCard({ missingModelGroups: [] })
|
||||
expect(getRows()).toHaveLength(0)
|
||||
const { container } = mountCard({ missingModelGroups: [] })
|
||||
// eslint-disable-next-line testing-library/no-container, testing-library/no-node-access
|
||||
expect(container.querySelectorAll('.model-row')).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('hides bulk actions in cloud', () => {
|
||||
@@ -195,21 +191,31 @@ describe('MissingModelCard', () => {
|
||||
screen.queryByTestId('missing-model-actions')
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('passes props correctly to MissingModelRow children', () => {
|
||||
const { container } = mountCard({ showNodeIdBadge: true })
|
||||
// eslint-disable-next-line testing-library/no-container, testing-library/no-node-access
|
||||
const row = container.querySelector('.model-row')
|
||||
expect(row).not.toBeNull()
|
||||
expect(row!.getAttribute('data-show-node-id-badge')).toBe('true')
|
||||
expect(row!.getAttribute('data-is-asset-supported')).toBe('true')
|
||||
expect(row!.getAttribute('data-directory')).toBe('checkpoints')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Asset Unsupported Group', () => {
|
||||
it('does not show the unsupported group header in cloud', () => {
|
||||
it('shows "Import Not Supported" header for unsupported groups', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [makeGroup({ isAssetSupported: false })]
|
||||
})
|
||||
expect(container.textContent).not.toContain('Import Not Supported')
|
||||
expect(container.textContent).toContain('Import Not Supported')
|
||||
})
|
||||
|
||||
it('does not show the unsupported group notice in cloud', () => {
|
||||
it('shows info notice for unsupported groups', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [makeGroup({ isAssetSupported: false })]
|
||||
})
|
||||
expect(container.textContent).not.toContain(
|
||||
expect(container.textContent).toContain(
|
||||
'Cloud environment does not support model imports'
|
||||
)
|
||||
})
|
||||
@@ -245,12 +251,13 @@ describe('MissingModelCard (OSS)', () => {
|
||||
})
|
||||
|
||||
it('shows directory name instead of "Import Not Supported" for unsupported groups', () => {
|
||||
mountCard({
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({ directory: 'checkpoints', isAssetSupported: false })
|
||||
]
|
||||
})
|
||||
expect(getRows()[0].getAttribute('data-directory')).toBe('checkpoints')
|
||||
expect(container.textContent).toContain('checkpoints')
|
||||
expect(container.textContent).not.toContain('Import Not Supported')
|
||||
})
|
||||
|
||||
it('hides info notice for unsupported groups', () => {
|
||||
@@ -262,39 +269,61 @@ describe('MissingModelCard (OSS)', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('passes null directory for unknown category rows in OSS', () => {
|
||||
it('renders unknown category for null directory in OSS', () => {
|
||||
const { container } = mountCard({
|
||||
missingModelGroups: [
|
||||
makeGroup({ directory: null, isAssetSupported: false })
|
||||
]
|
||||
})
|
||||
expect(getRows()[0].hasAttribute('data-directory')).toBe(false)
|
||||
expect(container.textContent).toContain('Unknown Category')
|
||||
expect(container.textContent).not.toContain('Import Not Supported')
|
||||
})
|
||||
|
||||
it('shows Download all at the bottom when one model is downloadable', () => {
|
||||
it('shows bulk actions when one model is downloadable', () => {
|
||||
mountCard({
|
||||
missingModelGroups: [makeGroup({ withDownloadUrls: true })]
|
||||
})
|
||||
|
||||
const actions = screen.getByTestId('missing-model-actions')
|
||||
expect(actions).toBeVisible()
|
||||
expect(
|
||||
within(actions).getByRole('button', { name: /Download all/ })
|
||||
).toBeVisible()
|
||||
expect(
|
||||
within(actions).queryByRole('button', { name: 'Refresh' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /Download all/ })).toBeVisible()
|
||||
expect(screen.getByRole('button', { name: 'Refresh' })).toBeVisible()
|
||||
})
|
||||
|
||||
it('hides Download all when no model is downloadable', () => {
|
||||
it('hides bulk actions when no model is downloadable', () => {
|
||||
mountCard()
|
||||
|
||||
expect(
|
||||
screen.queryByRole('button', { name: /Download all/ })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByTestId('missing-model-actions')
|
||||
screen.queryByRole('button', { name: 'Refresh' })
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('refreshes missing models from the action bar', async () => {
|
||||
mountCard({
|
||||
missingModelGroups: [makeGroup({ withDownloadUrls: true })]
|
||||
})
|
||||
const store = useMissingModelStore()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Refresh' }))
|
||||
|
||||
expect(store.refreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps the Refresh button focusable and announces refresh progress', async () => {
|
||||
mountCard({
|
||||
missingModelGroups: [makeGroup({ withDownloadUrls: true })]
|
||||
})
|
||||
const store = useMissingModelStore()
|
||||
|
||||
store.isRefreshingMissingModels = true
|
||||
await nextTick()
|
||||
|
||||
const refreshButton = screen.getByRole('button', { name: 'Refresh' })
|
||||
expect(refreshButton).toHaveAttribute('aria-disabled', 'true')
|
||||
expect(refreshButton).toHaveAttribute('aria-busy', 'true')
|
||||
expect(screen.getByRole('status')).toHaveTextContent(
|
||||
'Refreshing missing models.'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,20 +1,9 @@
|
||||
<template>
|
||||
<div class="px-4 pb-2">
|
||||
<div class="flex flex-col gap-1 overflow-hidden py-2">
|
||||
<MissingModelRow
|
||||
v-for="row in sortedModelRows"
|
||||
:key="row.key"
|
||||
:model="row.model"
|
||||
:directory="row.directory"
|
||||
:is-asset-supported="row.isAssetSupported"
|
||||
@locate-model="emit('locateModel', $event)"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="downloadableModels.length > 0"
|
||||
data-testid="missing-model-actions"
|
||||
class="flex items-center pt-2"
|
||||
class="flex items-center gap-2 border-b border-interface-stroke py-2"
|
||||
>
|
||||
<Button
|
||||
data-testid="missing-model-download-all"
|
||||
@@ -26,6 +15,100 @@
|
||||
<i aria-hidden="true" class="icon-[lucide--download] size-4 shrink-0" />
|
||||
<span class="truncate">{{ downloadAllLabel }}</span>
|
||||
</Button>
|
||||
<!-- Keep this focusable while refreshing so the live status remains discoverable. -->
|
||||
<Button
|
||||
data-testid="missing-model-refresh"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="h-8 w-28 shrink-0 rounded-lg text-sm"
|
||||
:aria-busy="missingModelStore.isRefreshingMissingModels"
|
||||
:aria-disabled="missingModelStore.isRefreshingMissingModels"
|
||||
@click="handleRefreshClick"
|
||||
>
|
||||
<DotSpinner
|
||||
v-if="missingModelStore.isRefreshingMissingModels"
|
||||
aria-hidden="true"
|
||||
duration="1s"
|
||||
:size="12"
|
||||
/>
|
||||
<i
|
||||
v-else
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--refresh-cw] size-4 shrink-0"
|
||||
/>
|
||||
{{ t('rightSidePanel.missingModels.refresh') }}
|
||||
</Button>
|
||||
<span role="status" aria-live="polite" class="sr-only">
|
||||
{{
|
||||
missingModelStore.isRefreshingMissingModels
|
||||
? t('rightSidePanel.missingModels.refreshing')
|
||||
: ''
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Category groups (by directory) -->
|
||||
<div
|
||||
v-for="group in missingModelGroups"
|
||||
:key="`${group.isAssetSupported ? 'supported' : 'unsupported'}::${group.directory ?? '__unknown__'}`"
|
||||
class="flex w-full flex-col border-t border-interface-stroke py-2 first:border-t-0 first:pt-0"
|
||||
>
|
||||
<!-- Category header -->
|
||||
<div class="flex h-8 w-full items-center">
|
||||
<p
|
||||
class="min-w-0 flex-1 truncate text-sm font-medium"
|
||||
:class="
|
||||
(isCloud && !group.isAssetSupported) || group.directory === null
|
||||
? 'text-warning-background'
|
||||
: 'text-destructive-background-hover'
|
||||
"
|
||||
>
|
||||
<span v-if="isCloud && !group.isAssetSupported">
|
||||
{{ t('rightSidePanel.missingModels.importNotSupported') }}
|
||||
({{ group.models.length }})
|
||||
</span>
|
||||
<span v-else>
|
||||
<i
|
||||
v-if="group.directory === null"
|
||||
aria-hidden="true"
|
||||
class="mr-1 icon-[lucide--triangle-alert] size-3.5 align-text-bottom"
|
||||
/>
|
||||
{{
|
||||
group.directory ??
|
||||
t('rightSidePanel.missingModels.unknownCategory')
|
||||
}}
|
||||
({{ group.models.length }})
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Asset unsupported group notice -->
|
||||
<div
|
||||
v-if="isCloud && !group.isAssetSupported"
|
||||
data-testid="missing-model-import-unsupported"
|
||||
class="flex items-start gap-1.5 px-0.5 py-1 pl-2"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="mt-0.5 icon-[lucide--info] size-3.5 shrink-0 text-muted-foreground"
|
||||
/>
|
||||
<span class="text-xs/tight text-muted-foreground">
|
||||
{{ t('rightSidePanel.missingModels.customNodeDownloadDisabled') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Model rows -->
|
||||
<div class="flex flex-col gap-1 overflow-hidden pl-2">
|
||||
<MissingModelRow
|
||||
v-for="model in group.models"
|
||||
:key="model.name"
|
||||
:model="model"
|
||||
:directory="group.directory"
|
||||
:show-node-id-badge="showNodeIdBadge"
|
||||
:is-asset-supported="group.isAssetSupported"
|
||||
@locate-model="emit('locateModel', $event)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -37,28 +120,15 @@ import type { MissingModelGroup } from '@/platform/missingModel/types'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import MissingModelRow from '@/platform/missingModel/components/MissingModelRow.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import DotSpinner from '@/components/common/DotSpinner.vue'
|
||||
import { downloadModel } from '@/platform/missingModel/missingModelDownload'
|
||||
import { getDownloadableModels } from '@/platform/missingModel/missingModelViewUtils'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { formatSize } from '@/utils/formatUtil'
|
||||
|
||||
interface MissingModelRowEntry {
|
||||
key: string
|
||||
model: MissingModelGroup['models'][number]
|
||||
directory: string | null
|
||||
isAssetSupported: boolean
|
||||
}
|
||||
|
||||
const MODEL_TYPE_SORT_ORDER = [
|
||||
'checkpoints',
|
||||
'loras',
|
||||
'vae',
|
||||
'text_encoders',
|
||||
'diffusion_models'
|
||||
] as const
|
||||
|
||||
const { missingModelGroups } = defineProps<{
|
||||
const { missingModelGroups, showNodeIdBadge } = defineProps<{
|
||||
missingModelGroups: MissingModelGroup[]
|
||||
showNodeIdBadge: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -68,19 +138,6 @@ const emit = defineEmits<{
|
||||
const { t } = useI18n()
|
||||
const missingModelStore = useMissingModelStore()
|
||||
|
||||
const sortedModelRows = computed(() =>
|
||||
missingModelGroups
|
||||
.flatMap((group) =>
|
||||
group.models.map((model, index) => ({
|
||||
key: getModelRowKey(group, model, index),
|
||||
model,
|
||||
directory: group.directory,
|
||||
isAssetSupported: group.isAssetSupported
|
||||
}))
|
||||
)
|
||||
.sort((a, b) => compareModelRows(a, b))
|
||||
)
|
||||
|
||||
const downloadableModels = computed(() => {
|
||||
if (isCloud) return []
|
||||
|
||||
@@ -102,33 +159,7 @@ function downloadAllModels() {
|
||||
}
|
||||
}
|
||||
|
||||
function getModelRowKey(
|
||||
group: MissingModelGroup,
|
||||
model: MissingModelGroup['models'][number],
|
||||
index: number
|
||||
) {
|
||||
const supportKey = group.isAssetSupported ? 'supported' : 'unsupported'
|
||||
return [
|
||||
supportKey,
|
||||
group.directory ?? '__unknown__',
|
||||
model.name,
|
||||
String(index)
|
||||
].join('::')
|
||||
}
|
||||
|
||||
function compareModelRows(a: MissingModelRowEntry, b: MissingModelRowEntry) {
|
||||
return (
|
||||
getModelTypeSortIndex(a.directory) - getModelTypeSortIndex(b.directory) ||
|
||||
(a.directory ?? '').localeCompare(b.directory ?? '') ||
|
||||
a.model.name.localeCompare(b.model.name)
|
||||
)
|
||||
}
|
||||
|
||||
function getModelTypeSortIndex(directory: string | null) {
|
||||
if (directory === null) return Number.MAX_SAFE_INTEGER
|
||||
const index = MODEL_TYPE_SORT_ORDER.indexOf(
|
||||
directory as (typeof MODEL_TYPE_SORT_ORDER)[number]
|
||||
)
|
||||
return index === -1 ? MODEL_TYPE_SORT_ORDER.length : index
|
||||
function handleRefreshClick() {
|
||||
void missingModelStore.refreshMissingModels()
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
<template>
|
||||
<div class="flex flex-col gap-2">
|
||||
<div v-if="showDivider" class="flex items-center justify-center py-0.5">
|
||||
<span class="text-xs font-bold text-muted-foreground">
|
||||
{{ t('rightSidePanel.missingModels.or') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
:model-value="modelValue"
|
||||
:disabled="options.length === 0"
|
||||
@update:model-value="handleSelect"
|
||||
>
|
||||
<SelectTrigger
|
||||
size="md"
|
||||
class="border-transparent bg-secondary-background text-xs hover:border-interface-stroke"
|
||||
>
|
||||
<SelectValue
|
||||
:placeholder="t('rightSidePanel.missingModels.useFromLibrary')"
|
||||
/>
|
||||
</SelectTrigger>
|
||||
|
||||
<SelectContent>
|
||||
<template v-if="options.length > 4" #prepend>
|
||||
<div class="px-1 pb-1.5">
|
||||
<div
|
||||
class="flex items-center gap-1.5 rounded-md border border-border-default px-2"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--search] size-3.5 shrink-0 text-muted-foreground"
|
||||
/>
|
||||
<input
|
||||
v-model="filterQuery"
|
||||
type="text"
|
||||
:aria-label="t('g.searchPlaceholder', { subject: '' })"
|
||||
class="h-7 w-full border-none bg-transparent text-xs outline-none placeholder:text-muted-foreground"
|
||||
:placeholder="t('g.searchPlaceholder', { subject: '' })"
|
||||
@keydown.stop
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<SelectItem
|
||||
v-for="option in filteredOptions"
|
||||
:key="option.value"
|
||||
:value="option.value"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ option.name }}
|
||||
</SelectItem>
|
||||
<div
|
||||
v-if="filteredOptions.length === 0"
|
||||
role="status"
|
||||
class="px-3 py-2 text-xs text-muted-foreground"
|
||||
>
|
||||
{{ t('g.noResultsFound') }}
|
||||
</div>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useFuse } from '@vueuse/integrations/useFuse'
|
||||
import Select from '@/components/ui/select/Select.vue'
|
||||
import SelectContent from '@/components/ui/select/SelectContent.vue'
|
||||
import SelectTrigger from '@/components/ui/select/SelectTrigger.vue'
|
||||
import SelectValue from '@/components/ui/select/SelectValue.vue'
|
||||
import SelectItem from '@/components/ui/select/SelectItem.vue'
|
||||
|
||||
const { options, showDivider = false } = defineProps<{
|
||||
modelValue: string | undefined
|
||||
options: { name: string; value: string }[]
|
||||
showDivider?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
select: [value: string]
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const filterQuery = ref('')
|
||||
|
||||
watch(
|
||||
() => options.length,
|
||||
(len) => {
|
||||
if (len <= 4) filterQuery.value = ''
|
||||
}
|
||||
)
|
||||
|
||||
const { results: fuseResults } = useFuse(filterQuery, () => options, {
|
||||
fuseOptions: {
|
||||
keys: ['name'],
|
||||
threshold: 0.4,
|
||||
ignoreLocation: true
|
||||
},
|
||||
matchAllWhenSearchEmpty: true
|
||||
})
|
||||
|
||||
const filteredOptions = computed(() => fuseResults.value.map((r) => r.item))
|
||||
|
||||
function handleSelect(value: unknown) {
|
||||
if (typeof value === 'string') {
|
||||
filterQuery.value = ''
|
||||
emit('select', value)
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -1,442 +0,0 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { render, screen, waitFor } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
import type {
|
||||
UploadModelDialogContext,
|
||||
UploadModelSuccess
|
||||
} from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import type { MissingModelViewModel } from '@/platform/missingModel/types'
|
||||
import type * as MissingModelDownload from '@/platform/missingModel/missingModelDownload'
|
||||
import type * as GraphTraversalUtil from '@/utils/graphTraversalUtil'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockShowUploadDialog = vi.hoisted(() => vi.fn())
|
||||
const mockCopyToClipboard = vi.hoisted(() => vi.fn())
|
||||
const mockDownloadModel = vi.hoisted(() => vi.fn())
|
||||
const mockRootGraph = vi.hoisted<{
|
||||
value: Record<string, never> | null
|
||||
}>(() => ({ value: null }))
|
||||
const mockGetNodeByExecutionId = vi.hoisted(() => vi.fn())
|
||||
const mockApiListeners = vi.hoisted(
|
||||
() => new Map<string, (event: CustomEvent) => void>()
|
||||
)
|
||||
type UploadModelContextResolver = () => UploadModelDialogContext | undefined
|
||||
const mockUploadContext = vi.hoisted(() => ({
|
||||
resolver: undefined as UploadModelContextResolver | undefined
|
||||
}))
|
||||
const mockUploadCallbacks = vi.hoisted(() => ({
|
||||
onUploadSuccess: undefined as
|
||||
| ((result: UploadModelSuccess) => Promise<unknown> | unknown)
|
||||
| undefined
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
get rootGraph() {
|
||||
return mockRootGraph.value
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(
|
||||
(event: string, handler: (event: CustomEvent) => void) => {
|
||||
mockApiListeners.set(event, handler)
|
||||
}
|
||||
),
|
||||
apiURL: vi.fn((path: string) => path),
|
||||
fetchApi: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/graphTraversalUtil', async () => {
|
||||
const actual = await vi.importActual<typeof GraphTraversalUtil>(
|
||||
'@/utils/graphTraversalUtil'
|
||||
)
|
||||
return {
|
||||
...actual,
|
||||
getActiveGraphNodeIds: vi.fn(() => new Set()),
|
||||
getNodeByExecutionId: mockGetNodeByExecutionId
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
return mockIsCloud.value
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/composables/useModelUpload', () => ({
|
||||
useModelUpload: (
|
||||
onUploadSuccess?: (
|
||||
result: UploadModelSuccess
|
||||
) => Promise<unknown> | unknown,
|
||||
uploadContext?: UploadModelDialogContext | UploadModelContextResolver
|
||||
) => {
|
||||
mockUploadCallbacks.onUploadSuccess = onUploadSuccess
|
||||
mockUploadContext.resolver =
|
||||
typeof uploadContext === 'function' ? uploadContext : () => uploadContext
|
||||
|
||||
return {
|
||||
isUploadButtonEnabled: { value: true },
|
||||
showUploadDialog: mockShowUploadDialog
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useCopyToClipboard', () => ({
|
||||
useCopyToClipboard: () => ({
|
||||
copyToClipboard: mockCopyToClipboard
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/missingModel/missingModelDownload', async () => {
|
||||
const actual = await vi.importActual<typeof MissingModelDownload>(
|
||||
'@/platform/missingModel/missingModelDownload'
|
||||
)
|
||||
return {
|
||||
...actual,
|
||||
downloadModel: mockDownloadModel,
|
||||
fetchModelMetadata: vi.fn().mockResolvedValue({
|
||||
fileSize: null,
|
||||
gatedRepoUrl: null
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
import MissingModelRow from './MissingModelRow.vue'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
const TransitionCollapseStub = {
|
||||
name: 'TransitionCollapse',
|
||||
template: '<div><slot /></div>'
|
||||
}
|
||||
|
||||
function makeModel(
|
||||
refs: MissingModelViewModel['referencingNodes']
|
||||
): MissingModelViewModel {
|
||||
return {
|
||||
name: 'model.safetensors',
|
||||
representative: {
|
||||
nodeId: refs[0]?.nodeId,
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
widgetName: 'ckpt_name',
|
||||
isAssetSupported: true,
|
||||
name: 'model.safetensors',
|
||||
directory: 'checkpoints',
|
||||
url: 'https://example.com/model.safetensors',
|
||||
isMissing: true
|
||||
},
|
||||
referencingNodes: refs
|
||||
}
|
||||
}
|
||||
|
||||
function renderRow(
|
||||
model: MissingModelViewModel,
|
||||
onLocateModel = vi.fn(),
|
||||
isAssetSupported = true,
|
||||
directory: string | null = 'checkpoints'
|
||||
) {
|
||||
const pinia = createPinia()
|
||||
setActivePinia(pinia)
|
||||
|
||||
render(MissingModelRow, {
|
||||
props: {
|
||||
model,
|
||||
directory,
|
||||
isAssetSupported,
|
||||
onLocateModel
|
||||
},
|
||||
global: {
|
||||
plugins: [pinia, i18n],
|
||||
stubs: {
|
||||
TransitionCollapse: TransitionCollapseStub
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return { onLocateModel }
|
||||
}
|
||||
|
||||
describe('MissingModelRow', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = true
|
||||
mockShowUploadDialog.mockClear()
|
||||
mockCopyToClipboard.mockClear()
|
||||
mockDownloadModel.mockClear()
|
||||
mockRootGraph.value = null
|
||||
mockGetNodeByExecutionId.mockReset()
|
||||
mockUploadContext.resolver = undefined
|
||||
mockUploadCallbacks.onUploadSuccess = undefined
|
||||
})
|
||||
|
||||
it('opens the model import dialog from the cloud row', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderRow(makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }]))
|
||||
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'Import' }))
|
||||
|
||||
expect(mockShowUploadDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockUploadContext.resolver?.()).toEqual({
|
||||
kind: 'missing-model-resolution',
|
||||
missingModelName: 'model.safetensors',
|
||||
requiredModelType: 'checkpoints',
|
||||
replacementTargets: [
|
||||
{
|
||||
nodeId: '1',
|
||||
nodeLabel: 'CheckpointLoaderSimple',
|
||||
widgetName: 'ckpt_name'
|
||||
}
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('shows row progress as soon as the model import starts', async () => {
|
||||
renderRow(makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }]))
|
||||
const store = useMissingModelStore()
|
||||
|
||||
await mockUploadCallbacks.onUploadSuccess?.({
|
||||
filename: 'downloaded-model.safetensors',
|
||||
modelType: 'checkpoints',
|
||||
taskId: 'task-1',
|
||||
status: 'processing'
|
||||
})
|
||||
await nextTick()
|
||||
|
||||
expect(
|
||||
store.importTaskIds['supported::checkpoints::model.safetensors']
|
||||
).toBe('task-1')
|
||||
expect(
|
||||
screen.getByRole('progressbar', { name: 'Importing...' })
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByRole('status')).toHaveTextContent('Importing...')
|
||||
expect(screen.getByText('downloaded-model.safetensors')).toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: 'Import' })).toBeNull()
|
||||
})
|
||||
|
||||
it('applies the completed imported model to every referencing node', async () => {
|
||||
const graph = {}
|
||||
const firstWidget = {
|
||||
name: 'ckpt_name',
|
||||
value: 'old-first.safetensors',
|
||||
callback: vi.fn()
|
||||
}
|
||||
const secondWidget = {
|
||||
name: 'ckpt_name',
|
||||
value: 'old-second.safetensors',
|
||||
callback: vi.fn()
|
||||
}
|
||||
const firstSetDirtyCanvas = vi.fn()
|
||||
const secondSetDirtyCanvas = vi.fn()
|
||||
mockRootGraph.value = graph
|
||||
mockGetNodeByExecutionId.mockImplementation((_graph, nodeId) => {
|
||||
if (nodeId === '1') {
|
||||
return {
|
||||
widgets: [firstWidget],
|
||||
graph: { setDirtyCanvas: firstSetDirtyCanvas }
|
||||
}
|
||||
}
|
||||
if (nodeId === '2') {
|
||||
return {
|
||||
widgets: [secondWidget],
|
||||
graph: { setDirtyCanvas: secondSetDirtyCanvas }
|
||||
}
|
||||
}
|
||||
return null
|
||||
})
|
||||
|
||||
renderRow(
|
||||
makeModel([
|
||||
{ nodeId: '1', widgetName: 'ckpt_name' },
|
||||
{ nodeId: '2', widgetName: 'ckpt_name' }
|
||||
])
|
||||
)
|
||||
|
||||
await mockUploadCallbacks.onUploadSuccess?.({
|
||||
filename: 'client-name.safetensors',
|
||||
modelType: 'checkpoints',
|
||||
taskId: 'task-1',
|
||||
status: 'processing'
|
||||
})
|
||||
await nextTick()
|
||||
|
||||
const handler = mockApiListeners.get('asset_download')
|
||||
expect(handler).toBeDefined()
|
||||
handler!(
|
||||
new CustomEvent('asset_download', {
|
||||
detail: {
|
||||
task_id: 'task-1',
|
||||
asset_name: 'server-name.safetensors',
|
||||
bytes_total: 100,
|
||||
bytes_downloaded: 100,
|
||||
progress: 1,
|
||||
status: 'completed'
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(firstWidget.value).toBe('server-name.safetensors')
|
||||
expect(secondWidget.value).toBe('server-name.safetensors')
|
||||
})
|
||||
expect(firstWidget.callback).toHaveBeenCalledWith('server-name.safetensors')
|
||||
expect(secondWidget.callback).toHaveBeenCalledWith(
|
||||
'server-name.safetensors'
|
||||
)
|
||||
expect(firstSetDirtyCanvas).toHaveBeenCalledWith(true, true)
|
||||
expect(secondSetDirtyCanvas).toHaveBeenCalledWith(true, true)
|
||||
})
|
||||
|
||||
it('locates the parent row directly when a cloud model has one reference', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { onLocateModel } = renderRow(
|
||||
makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }])
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'model.safetensors' }))
|
||||
|
||||
expect(onLocateModel).toHaveBeenCalledWith('1')
|
||||
})
|
||||
|
||||
it('moves locate actions to expanded child rows when a cloud model has multiple references', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { onLocateModel } = renderRow(
|
||||
makeModel([
|
||||
{ nodeId: '1', widgetName: 'ckpt_name' },
|
||||
{ nodeId: '2', widgetName: 'ckpt_name' }
|
||||
])
|
||||
)
|
||||
|
||||
expect(screen.getByText('2')).toBeInTheDocument()
|
||||
expect(screen.queryAllByTestId('missing-model-locate')).toHaveLength(0)
|
||||
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Show referencing nodes' })
|
||||
)
|
||||
|
||||
const locateButtons = screen.getAllByTestId('missing-model-locate')
|
||||
expect(locateButtons).toHaveLength(2)
|
||||
|
||||
await user.click(locateButtons[1])
|
||||
|
||||
expect(onLocateModel).toHaveBeenCalledWith('2')
|
||||
})
|
||||
|
||||
it('locates the parent row directly when an OSS model has one reference', async () => {
|
||||
mockIsCloud.value = false
|
||||
const user = userEvent.setup()
|
||||
const { onLocateModel } = renderRow(
|
||||
makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }])
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'model.safetensors' }))
|
||||
|
||||
expect(onLocateModel).toHaveBeenCalledWith('1')
|
||||
})
|
||||
|
||||
it('does not show the library selector in OSS rows', () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
renderRow(makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }]))
|
||||
|
||||
expect(
|
||||
screen.getByPlaceholderText('Paste Model URL (Civitai or Hugging Face)')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows model type metadata below the model name', () => {
|
||||
renderRow(makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }]))
|
||||
|
||||
expect(screen.getByText('checkpoints')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows downloadable model size beside the model type metadata', async () => {
|
||||
mockIsCloud.value = false
|
||||
const model = makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }])
|
||||
model.representative.url =
|
||||
'https://huggingface.co/comfy/test/resolve/main/model.safetensors'
|
||||
|
||||
renderRow(model, vi.fn(), false)
|
||||
const store = useMissingModelStore()
|
||||
store.fileSizes[model.representative.url] = 14 * 1024 ** 3
|
||||
await nextTick()
|
||||
|
||||
expect(screen.getByText('checkpoints · 14 GB')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('missing-model-download')).toHaveTextContent(
|
||||
'Download'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows unknown category metadata for models without a directory', () => {
|
||||
renderRow(
|
||||
makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }]),
|
||||
vi.fn(),
|
||||
true,
|
||||
null
|
||||
)
|
||||
|
||||
expect(screen.getByText('Unknown')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('moves locate actions to expanded child rows when an OSS model has multiple references', async () => {
|
||||
mockIsCloud.value = false
|
||||
const user = userEvent.setup()
|
||||
const { onLocateModel } = renderRow(
|
||||
makeModel([
|
||||
{ nodeId: '1', widgetName: 'ckpt_name' },
|
||||
{ nodeId: '2', widgetName: 'ckpt_name' }
|
||||
])
|
||||
)
|
||||
|
||||
expect(screen.getByText('2')).toBeInTheDocument()
|
||||
expect(screen.queryAllByTestId('missing-model-locate')).toHaveLength(0)
|
||||
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Show referencing nodes' })
|
||||
)
|
||||
|
||||
const locateButtons = screen.getAllByTestId('missing-model-locate')
|
||||
expect(locateButtons).toHaveLength(2)
|
||||
|
||||
await user.click(locateButtons[1])
|
||||
|
||||
expect(onLocateModel).toHaveBeenCalledWith('2')
|
||||
})
|
||||
|
||||
it('shows the OSS download action in the row for downloadable models', async () => {
|
||||
mockIsCloud.value = false
|
||||
const user = userEvent.setup()
|
||||
const model = makeModel([{ nodeId: '1', widgetName: 'ckpt_name' }])
|
||||
model.representative.url =
|
||||
'https://huggingface.co/comfy/test/resolve/main/model.safetensors'
|
||||
|
||||
renderRow(model, vi.fn(), false)
|
||||
|
||||
await user.click(screen.getByTestId('missing-model-download'))
|
||||
|
||||
expect(mockDownloadModel).toHaveBeenCalledWith(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/comfy/test/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{}
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,11 +1,72 @@
|
||||
<template>
|
||||
<div class="mb-1 flex w-full flex-col gap-0.5 last:mb-0">
|
||||
<div class="flex min-h-8 w-full items-center gap-1">
|
||||
<div class="flex w-full flex-col pb-3">
|
||||
<!-- Model header -->
|
||||
<div class="flex h-8 w-full items-center gap-2">
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="text-foreground icon-[lucide--file-check] size-4 shrink-0"
|
||||
/>
|
||||
|
||||
<div class="flex min-w-0 flex-1 items-center">
|
||||
<p
|
||||
class="text-foreground min-w-0 truncate text-sm font-medium"
|
||||
:title="model.name"
|
||||
>
|
||||
{{ model.name }} ({{ model.referencingNodes.length }})
|
||||
</p>
|
||||
|
||||
<Button
|
||||
data-testid="missing-model-copy-name"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
class="size-8 shrink-0 hover:bg-transparent"
|
||||
:aria-label="t('rightSidePanel.missingModels.copyModelName')"
|
||||
:title="t('rightSidePanel.missingModels.copyModelName')"
|
||||
@click="copyToClipboard(model.name)"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--copy] size-3.5 text-muted-foreground"
|
||||
/>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
v-if="hasMultipleReferences"
|
||||
v-if="!isCloud && model.representative.url && !isAssetSupported"
|
||||
data-testid="missing-model-copy-url"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="h-8 shrink-0 rounded-lg text-sm"
|
||||
@click="copyToClipboard(toBrowsableUrl(model.representative.url!))"
|
||||
>
|
||||
{{ t('rightSidePanel.missingModels.copyUrl') }}
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
:aria-label="t('rightSidePanel.missingModels.confirmSelection')"
|
||||
:disabled="!canConfirm"
|
||||
:class="
|
||||
cn(
|
||||
'size-8 shrink-0 rounded-lg transition-colors',
|
||||
canConfirm ? 'bg-primary/10 hover:bg-primary/15' : 'opacity-20'
|
||||
)
|
||||
"
|
||||
@click="handleLibrarySelect"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--check] size-4"
|
||||
:class="canConfirm ? 'text-primary' : 'text-foreground'"
|
||||
/>
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
v-if="model.referencingNodes.length > 0"
|
||||
data-testid="missing-model-expand"
|
||||
variant="textonly"
|
||||
size="unset"
|
||||
size="icon-sm"
|
||||
:aria-label="
|
||||
expanded
|
||||
? t('rightSidePanel.missingModels.collapseNodes')
|
||||
@@ -14,240 +75,116 @@
|
||||
:aria-expanded="expanded"
|
||||
:class="
|
||||
cn(
|
||||
'h-8 w-4 shrink-0 p-0 transition-transform duration-200 hover:bg-transparent',
|
||||
expanded && 'rotate-90'
|
||||
'size-8 shrink-0 transition-transform duration-200 hover:bg-transparent',
|
||||
expanded && 'rotate-180'
|
||||
)
|
||||
"
|
||||
@click="handleToggleExpand"
|
||||
@click="toggleModelExpand(modelKey)"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--chevron-right] size-4 text-muted-foreground"
|
||||
class="icon-[lucide--chevron-down] size-4 text-muted-foreground group-hover:text-base-foreground"
|
||||
/>
|
||||
</Button>
|
||||
|
||||
<span class="flex min-w-0 flex-1 flex-col gap-0">
|
||||
<span class="flex min-w-0 items-center gap-2">
|
||||
<span class="flex min-w-0 items-center gap-2.5">
|
||||
<button
|
||||
v-if="hasMultipleReferences"
|
||||
ref="modelLabelControl"
|
||||
type="button"
|
||||
class="m-0 inline max-w-full cursor-pointer appearance-none border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-base-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-0 focus-visible:outline-none"
|
||||
:title="displayModelName"
|
||||
@click="handleToggleExpand"
|
||||
>
|
||||
{{ displayModelName }}
|
||||
</button>
|
||||
<button
|
||||
v-else-if="!isUnknownCategory && primaryReference"
|
||||
ref="modelLabelControl"
|
||||
type="button"
|
||||
class="m-0 inline max-w-full cursor-pointer appearance-none border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-base-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-0 focus-visible:outline-none"
|
||||
:title="displayModelName"
|
||||
@click="handleLocatePrimary"
|
||||
>
|
||||
{{ displayModelName }}
|
||||
</button>
|
||||
<span
|
||||
v-else
|
||||
class="min-w-0 truncate text-sm/relaxed font-normal text-base-foreground"
|
||||
:title="displayModelName"
|
||||
>
|
||||
{{ displayModelName }}
|
||||
</span>
|
||||
<span
|
||||
v-if="hasMultipleReferences"
|
||||
data-testid="missing-model-reference-count"
|
||||
class="flex size-6 shrink-0 items-center justify-center rounded-md bg-secondary-background-selected text-xs font-bold text-muted-foreground"
|
||||
>
|
||||
{{ model.referencingNodes.length }}
|
||||
</span>
|
||||
</span>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
class="size-7 shrink-0 text-muted-foreground hover:bg-transparent hover:text-base-foreground"
|
||||
:aria-label="linkLabel"
|
||||
:title="linkLabel"
|
||||
@click="copyModelLink"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--link] size-4" />
|
||||
</Button>
|
||||
</span>
|
||||
<span
|
||||
v-if="modelMetadataLabel"
|
||||
class="text-2xs/none"
|
||||
:class="
|
||||
isUnknownCategory
|
||||
? 'text-warning-background'
|
||||
: 'text-muted-foreground'
|
||||
"
|
||||
>
|
||||
{{ modelMetadataLabel }}
|
||||
</span>
|
||||
</span>
|
||||
|
||||
<template v-if="isCloud">
|
||||
<Button
|
||||
v-if="!isCloudImportDownloadActive"
|
||||
data-testid="missing-model-import"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="h-8 shrink-0 rounded-lg text-sm"
|
||||
@click="showUploadDialog"
|
||||
>
|
||||
{{ t('g.import') }}
|
||||
</Button>
|
||||
<div
|
||||
v-else
|
||||
ref="cloudProgress"
|
||||
role="progressbar"
|
||||
:aria-label="t('rightSidePanel.missingModels.importing')"
|
||||
:aria-valuenow="cloudImportProgressPercent"
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="100"
|
||||
tabindex="-1"
|
||||
class="flex h-8 w-16 shrink-0 items-center"
|
||||
>
|
||||
<span
|
||||
class="block h-1.5 w-full overflow-hidden rounded-full bg-secondary-background-selected"
|
||||
>
|
||||
<span
|
||||
class="block h-full rounded-full bg-primary-background transition-all duration-200 ease-linear"
|
||||
:style="{ width: `${cloudImportProgressPercent}%` }"
|
||||
/>
|
||||
</span>
|
||||
</div>
|
||||
<span
|
||||
v-if="isCloudImportDownloadActive"
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
class="sr-only"
|
||||
>
|
||||
{{ t('rightSidePanel.missingModels.importing') }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template v-else>
|
||||
<Button
|
||||
v-if="showDownloadAction"
|
||||
data-testid="missing-model-download"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="h-8 shrink-0 rounded-lg text-sm"
|
||||
:aria-label="`${t('g.download')} ${model.name}`"
|
||||
@click="handleDownload"
|
||||
>
|
||||
{{ t('g.download') }}
|
||||
</Button>
|
||||
<Button
|
||||
v-else-if="showConfirmAction"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
:aria-label="t('rightSidePanel.missingModels.confirmSelection')"
|
||||
:disabled="!canConfirm"
|
||||
:class="
|
||||
cn(
|
||||
'size-8 shrink-0 rounded-lg transition-colors',
|
||||
canConfirm ? 'bg-primary/10 hover:bg-primary/15' : 'opacity-20'
|
||||
)
|
||||
"
|
||||
@click="handleLibrarySelect"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--check] size-4"
|
||||
:class="canConfirm ? 'text-primary' : 'text-foreground'"
|
||||
/>
|
||||
</Button>
|
||||
</template>
|
||||
|
||||
<Button
|
||||
v-if="!hasMultipleReferences && !isUnknownCategory && primaryReference"
|
||||
data-testid="missing-model-locate"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
:aria-label="t('rightSidePanel.missingModels.locateNode')"
|
||||
class="size-8 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
@click="handleLocatePrimary"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--locate] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<!-- Referencing nodes -->
|
||||
<TransitionCollapse>
|
||||
<ul
|
||||
v-if="showReferenceList"
|
||||
:class="
|
||||
cn(
|
||||
'm-0 list-none space-y-1 p-0',
|
||||
(hasMultipleReferences || isUnknownCategory) && 'pl-5'
|
||||
)
|
||||
"
|
||||
<div
|
||||
v-if="expanded"
|
||||
class="mb-1 flex flex-col gap-0.5 overflow-hidden pl-6"
|
||||
>
|
||||
<li
|
||||
<div
|
||||
v-for="ref in model.referencingNodes"
|
||||
:key="`${String(ref.nodeId)}::${ref.widgetName}`"
|
||||
class="min-w-0"
|
||||
class="flex h-7 items-center"
|
||||
>
|
||||
<div class="flex min-w-0 items-center gap-2">
|
||||
<button
|
||||
type="button"
|
||||
class="m-0 inline max-w-full cursor-pointer appearance-none border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-muted-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-0 focus-visible:outline-none"
|
||||
@click="emit('locateModel', String(ref.nodeId))"
|
||||
>
|
||||
{{
|
||||
getNodeDisplayLabel(ref.nodeId, model.representative.nodeType)
|
||||
}}
|
||||
</button>
|
||||
<Button
|
||||
data-testid="missing-model-locate"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
:aria-label="t('rightSidePanel.missingModels.locateNode')"
|
||||
class="ml-auto size-8 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
@click="emit('locateModel', String(ref.nodeId))"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--locate] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
<span
|
||||
v-if="showNodeIdBadge"
|
||||
class="mr-1 shrink-0 rounded-md bg-secondary-background-selected px-2 py-0.5 font-mono text-xs font-bold text-muted-foreground"
|
||||
>
|
||||
#{{ ref.nodeId }}
|
||||
</span>
|
||||
<p class="min-w-0 flex-1 truncate text-xs text-muted-foreground">
|
||||
{{ getNodeDisplayLabel(ref.nodeId, model.representative.nodeType) }}
|
||||
</p>
|
||||
<Button
|
||||
data-testid="missing-model-locate"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
:aria-label="t('rightSidePanel.missingModels.locateNode')"
|
||||
class="mr-1 size-6 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
@click="emit('locateModel', String(ref.nodeId))"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--locate] size-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</TransitionCollapse>
|
||||
|
||||
<template v-if="!isCloud">
|
||||
<TransitionCollapse>
|
||||
<MissingModelStatusCard
|
||||
v-if="selectedLibraryModel[modelKey]"
|
||||
:model-name="selectedLibraryModel[modelKey]"
|
||||
:is-download-active="isDownloadActive"
|
||||
:download-status="downloadStatus"
|
||||
:category-mismatch="importCategoryMismatch[modelKey]"
|
||||
@cancel="cancelLibrarySelect(modelKey)"
|
||||
/>
|
||||
</TransitionCollapse>
|
||||
<!-- Status card -->
|
||||
<TransitionCollapse>
|
||||
<MissingModelStatusCard
|
||||
v-if="selectedLibraryModel[modelKey]"
|
||||
:model-name="selectedLibraryModel[modelKey]"
|
||||
:is-download-active="isDownloadActive"
|
||||
:download-status="downloadStatus"
|
||||
:category-mismatch="importCategoryMismatch[modelKey]"
|
||||
@cancel="cancelLibrarySelect(modelKey)"
|
||||
/>
|
||||
</TransitionCollapse>
|
||||
|
||||
<TransitionCollapse>
|
||||
<div
|
||||
v-if="!selectedLibraryModel[modelKey]"
|
||||
class="mt-1 flex flex-col gap-1"
|
||||
>
|
||||
<div v-if="isAssetSupported" class="flex w-full flex-col py-1">
|
||||
<MissingModelUrlInput
|
||||
:model-key="modelKey"
|
||||
:directory="directory"
|
||||
:type-mismatch="typeMismatch"
|
||||
/>
|
||||
</div>
|
||||
<!-- Input area -->
|
||||
<TransitionCollapse>
|
||||
<div
|
||||
v-if="!selectedLibraryModel[modelKey]"
|
||||
class="mt-1 flex flex-col gap-1"
|
||||
>
|
||||
<div v-if="isAssetSupported" class="flex w-full flex-col py-1">
|
||||
<MissingModelUrlInput
|
||||
:model-key="modelKey"
|
||||
:directory="directory"
|
||||
:type-mismatch="typeMismatch"
|
||||
/>
|
||||
</div>
|
||||
</TransitionCollapse>
|
||||
</template>
|
||||
<div
|
||||
v-else-if="!isCloud && downloadable"
|
||||
class="flex w-full items-start py-1"
|
||||
>
|
||||
<Button
|
||||
data-testid="missing-model-download"
|
||||
variant="secondary"
|
||||
size="md"
|
||||
class="flex w-full flex-1"
|
||||
:aria-label="`${t('g.download')} ${model.name}`"
|
||||
@click="handleDownload"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="text-foreground mr-1 icon-[lucide--download] size-4 shrink-0"
|
||||
/>
|
||||
<span class="text-foreground min-w-0 truncate text-sm">
|
||||
{{ downloadLabel }}
|
||||
</span>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<TransitionCollapse>
|
||||
<MissingModelLibrarySelect
|
||||
v-if="!urlInputs[modelKey]"
|
||||
:model-value="getComboValue(model.representative)"
|
||||
:options="comboOptions"
|
||||
:show-divider="isAssetSupported || downloadable"
|
||||
@select="handleComboSelect(modelKey, $event)"
|
||||
/>
|
||||
</TransitionCollapse>
|
||||
</div>
|
||||
</TransitionCollapse>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, nextTick, onMounted, useTemplateRef, watch } from 'vue'
|
||||
import { computed, onMounted } from 'vue'
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
@@ -255,14 +192,14 @@ import Button from '@/components/ui/button/Button.vue'
|
||||
import TransitionCollapse from '@/components/rightSidePanel/layout/TransitionCollapse.vue'
|
||||
import MissingModelStatusCard from '@/platform/missingModel/components/MissingModelStatusCard.vue'
|
||||
import MissingModelUrlInput from '@/platform/missingModel/components/MissingModelUrlInput.vue'
|
||||
import MissingModelLibrarySelect from '@/platform/missingModel/components/MissingModelLibrarySelect.vue'
|
||||
import type { MissingModelViewModel } from '@/platform/missingModel/types'
|
||||
import type { UploadModelDialogContext } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
|
||||
import { useModelUpload } from '@/platform/assets/composables/useModelUpload'
|
||||
import {
|
||||
useMissingModelInteractions,
|
||||
getModelStateKey,
|
||||
getNodeDisplayLabel
|
||||
getNodeDisplayLabel,
|
||||
getComboValue
|
||||
} from '@/platform/missingModel/composables/useMissingModelInteractions'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { useCopyToClipboard } from '@/composables/useCopyToClipboard'
|
||||
@@ -278,6 +215,7 @@ import { formatSize } from '@/utils/formatUtil'
|
||||
const { model, directory, isAssetSupported } = defineProps<{
|
||||
model: MissingModelViewModel
|
||||
directory: string | null
|
||||
showNodeIdBadge: boolean
|
||||
isAssetSupported: boolean
|
||||
}>()
|
||||
|
||||
@@ -293,123 +231,21 @@ const modelKey = computed(() =>
|
||||
)
|
||||
|
||||
const downloadStatus = computed(() => getDownloadStatus(modelKey.value))
|
||||
const comboOptions = computed(() => getComboOptions(model.representative))
|
||||
const canConfirm = computed(() => isSelectionConfirmable(modelKey.value))
|
||||
const expanded = computed(() => isModelExpanded(modelKey.value))
|
||||
const typeMismatch = computed(() => getTypeMismatch(modelKey.value, directory))
|
||||
const isUnknownCategory = computed(() => directory === null)
|
||||
const isDownloadActive = computed(
|
||||
() =>
|
||||
downloadStatus.value?.status === 'running' ||
|
||||
downloadStatus.value?.status === 'created'
|
||||
)
|
||||
const isCloudImportDownloadActive = computed(
|
||||
() => isCloud && isDownloadActive.value
|
||||
)
|
||||
const cloudImportProgressPercent = computed(() =>
|
||||
Math.round((downloadStatus.value?.progress ?? 0) * 100)
|
||||
)
|
||||
const hasMultipleReferences = computed(() => model.referencingNodes.length > 1)
|
||||
const primaryReference = computed(() => model.referencingNodes[0])
|
||||
const linkLabel = computed(() =>
|
||||
model.representative.url
|
||||
? t('rightSidePanel.missingModels.copyUrl')
|
||||
: t('rightSidePanel.missingModels.copyModelName')
|
||||
)
|
||||
|
||||
const store = useMissingModelStore()
|
||||
const { selectedLibraryModel, importCategoryMismatch } = storeToRefs(store)
|
||||
const cloudProgress = useTemplateRef<HTMLElement>('cloudProgress')
|
||||
const modelLabelControl = useTemplateRef<HTMLButtonElement>('modelLabelControl')
|
||||
|
||||
const expanded = computed(
|
||||
() =>
|
||||
store.modelExpandState[modelKey.value] ??
|
||||
(isUnknownCategory.value && hasMultipleReferences.value)
|
||||
)
|
||||
const showReferenceList = computed(
|
||||
() =>
|
||||
(isUnknownCategory.value && model.referencingNodes.length === 1) ||
|
||||
(hasMultipleReferences.value && expanded.value)
|
||||
)
|
||||
|
||||
const displayModelName = computed(() => {
|
||||
if (!isCloudImportDownloadActive.value) return model.name
|
||||
|
||||
return (
|
||||
downloadStatus.value?.assetName ??
|
||||
selectedLibraryModel.value[modelKey.value] ??
|
||||
model.name
|
||||
)
|
||||
})
|
||||
|
||||
const downloadable = computed(() => {
|
||||
const rep = model.representative
|
||||
return !!(
|
||||
!isAssetSupported &&
|
||||
rep.url &&
|
||||
rep.directory &&
|
||||
isModelDownloadable({
|
||||
name: rep.name,
|
||||
url: rep.url,
|
||||
directory: rep.directory
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
const showDownloadAction = computed(
|
||||
() =>
|
||||
!isCloud &&
|
||||
downloadable.value &&
|
||||
!selectedLibraryModel.value[modelKey.value]
|
||||
)
|
||||
const showConfirmAction = computed(
|
||||
() => !isCloud && !!selectedLibraryModel.value[modelKey.value]
|
||||
)
|
||||
|
||||
const downloadSizeLabel = computed(() => {
|
||||
if (!showDownloadAction.value) return undefined
|
||||
|
||||
const url = model.representative.url
|
||||
const size = url ? store.fileSizes[url] : undefined
|
||||
return size ? formatSize(size) : undefined
|
||||
})
|
||||
const modelTypeLabel = computed(
|
||||
() => directory ?? t('rightSidePanel.missingModels.unknownCategory')
|
||||
)
|
||||
const modelMetadataLabel = computed(() =>
|
||||
[modelTypeLabel.value, downloadSizeLabel.value].filter(Boolean).join(' · ')
|
||||
)
|
||||
|
||||
const missingModelUploadContext = computed<
|
||||
UploadModelDialogContext | undefined
|
||||
>(() => {
|
||||
if (!directory) return undefined
|
||||
|
||||
return {
|
||||
kind: 'missing-model-resolution',
|
||||
missingModelName: model.name,
|
||||
requiredModelType: directory,
|
||||
replacementTargets: model.referencingNodes.map((ref) => ({
|
||||
nodeId: String(ref.nodeId),
|
||||
nodeLabel: getNodeDisplayLabel(ref.nodeId, model.representative.nodeType),
|
||||
widgetName: ref.widgetName
|
||||
}))
|
||||
}
|
||||
})
|
||||
|
||||
const { showUploadDialog } = useModelUpload(
|
||||
(result) => {
|
||||
handleUploadedModelImport(modelKey.value, result)
|
||||
|
||||
if (result.status === 'success') {
|
||||
handleLibrarySelect()
|
||||
}
|
||||
},
|
||||
() => missingModelUploadContext.value
|
||||
)
|
||||
const { selectedLibraryModel, importCategoryMismatch, urlInputs } =
|
||||
storeToRefs(store)
|
||||
|
||||
onMounted(() => {
|
||||
if (isCloud) return
|
||||
|
||||
const url = model.representative.url
|
||||
if (url && !store.fileSizes[url]) {
|
||||
fetchModelMetadata(url)
|
||||
@@ -427,6 +263,27 @@ onMounted(() => {
|
||||
}
|
||||
})
|
||||
|
||||
const downloadable = computed(() => {
|
||||
const rep = model.representative
|
||||
return !!(
|
||||
!isAssetSupported &&
|
||||
rep.url &&
|
||||
rep.directory &&
|
||||
isModelDownloadable({
|
||||
name: rep.name,
|
||||
url: rep.url,
|
||||
directory: rep.directory
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
const downloadLabel = computed(() => {
|
||||
const base = t('g.download')
|
||||
const url = model.representative.url
|
||||
const size = url ? store.fileSizes[url] : undefined
|
||||
return size ? `${base} (${formatSize(size)})` : base
|
||||
})
|
||||
|
||||
function handleDownload() {
|
||||
const rep = model.representative
|
||||
if (rep.url && rep.directory) {
|
||||
@@ -439,51 +296,18 @@ function handleDownload() {
|
||||
}
|
||||
}
|
||||
|
||||
function handleLocatePrimary() {
|
||||
const ref = primaryReference.value
|
||||
if (ref) emit('locateModel', String(ref.nodeId))
|
||||
}
|
||||
|
||||
function copyModelLink() {
|
||||
const url = model.representative.url
|
||||
copyToClipboard(url ? toBrowsableUrl(url) : model.name)
|
||||
}
|
||||
|
||||
const {
|
||||
toggleModelExpand,
|
||||
isModelExpanded,
|
||||
getComboOptions,
|
||||
handleComboSelect,
|
||||
isSelectionConfirmable,
|
||||
cancelLibrarySelect,
|
||||
confirmLibrarySelect,
|
||||
getTypeMismatch,
|
||||
getDownloadStatus,
|
||||
handleUploadedModelImport
|
||||
getDownloadStatus
|
||||
} = useMissingModelInteractions()
|
||||
|
||||
function handleToggleExpand() {
|
||||
store.modelExpandState[modelKey.value] = !expanded.value
|
||||
}
|
||||
|
||||
watch(
|
||||
() => downloadStatus.value?.status,
|
||||
(status) => {
|
||||
if (!isCloud || status !== 'completed') return
|
||||
const completedAssetName = downloadStatus.value?.assetName
|
||||
if (completedAssetName) {
|
||||
selectedLibraryModel.value[modelKey.value] = completedAssetName
|
||||
}
|
||||
handleLibrarySelect()
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
watch(isCloudImportDownloadActive, async (isActive, wasActive) => {
|
||||
await nextTick()
|
||||
if (isActive) {
|
||||
cloudProgress.value?.focus()
|
||||
} else if (wasActive) {
|
||||
modelLabelControl.value?.focus()
|
||||
}
|
||||
})
|
||||
|
||||
function handleLibrarySelect() {
|
||||
confirmLibrarySelect(
|
||||
modelKey.value,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
aria-live="polite"
|
||||
class="bg-foreground/5 relative mt-1 overflow-hidden rounded-lg border border-interface-stroke p-2"
|
||||
>
|
||||
<!-- Progress bar fill -->
|
||||
<div
|
||||
v-if="isDownloadActive"
|
||||
class="absolute inset-y-0 left-0 bg-primary/10 transition-all duration-200 ease-linear"
|
||||
@@ -64,7 +65,7 @@
|
||||
}}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ t('rightSidePanel.missingModels.readyToApply') }}
|
||||
{{ t('rightSidePanel.missingModels.usingFromLibrary') }}
|
||||
</template>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createApp } from 'vue'
|
||||
import type { App } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
import type { MissingModelCandidate } from '@/platform/missingModel/types'
|
||||
|
||||
const mockGetNodeByExecutionId = vi.fn()
|
||||
@@ -14,16 +10,29 @@ const mockGetAssetMetadata = vi.fn()
|
||||
const mockUploadAssetAsync = vi.fn()
|
||||
const mockTrackDownload = vi.fn()
|
||||
const mockInvalidateModelsForCategory = vi.fn()
|
||||
const mockGetAssetDisplayName = vi.fn((a: { name: string }) => a.name)
|
||||
const mockGetAssetFilename = vi.fn((a: { name: string }) => a.name)
|
||||
const mockGetAssets = vi.fn()
|
||||
const mockUpdateModelsForNodeType = vi.fn()
|
||||
const mockGetAllNodeProviders = vi.fn()
|
||||
const mockDownloadList = vi.fn(
|
||||
(): Array<{ taskId: string; status: string }> => []
|
||||
)
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
st: vi.fn((_key: string, fallback: string) => fallback)
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: false
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
rootGraph: null
|
||||
@@ -46,6 +55,7 @@ vi.mock('@/renderer/core/canvas/canvasStore', () => ({
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
getAssets: mockGetAssets,
|
||||
updateModelsForNodeType: mockUpdateModelsForNodeType,
|
||||
invalidateModelsForCategory: mockInvalidateModelsForCategory,
|
||||
updateModelsForTag: vi.fn()
|
||||
@@ -74,6 +84,11 @@ vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/utils/assetMetadataUtils', () => ({
|
||||
getAssetDisplayName: (a: { name: string }) => mockGetAssetDisplayName(a),
|
||||
getAssetFilename: (a: { name: string }) => mockGetAssetFilename(a)
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/importSources/civitaiImportSource', () => ({
|
||||
civitaiImportSource: {
|
||||
type: 'civitai',
|
||||
@@ -97,6 +112,7 @@ vi.mock('@/platform/assets/utils/importSourceUtil', () => ({
|
||||
import { app } from '@/scripts/app'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import {
|
||||
getComboValue,
|
||||
getModelStateKey,
|
||||
getNodeDisplayLabel,
|
||||
useMissingModelInteractions
|
||||
@@ -117,54 +133,17 @@ function makeCandidate(
|
||||
}
|
||||
|
||||
describe('useMissingModelInteractions', () => {
|
||||
const mountedApps: App<Element>[] = []
|
||||
|
||||
function setupWithI18n<T>(factory: () => T): T {
|
||||
let result: T | undefined
|
||||
const host = document.createElement('div')
|
||||
const app = createApp({
|
||||
setup() {
|
||||
result = factory()
|
||||
return () => null
|
||||
}
|
||||
})
|
||||
app.use(
|
||||
createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages }
|
||||
})
|
||||
)
|
||||
app.mount(host)
|
||||
mountedApps.push(app)
|
||||
|
||||
if (result === undefined) {
|
||||
throw new Error('Composable setup did not run')
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
function setupMissingModelInteractions(): ReturnType<
|
||||
typeof useMissingModelInteractions
|
||||
> {
|
||||
return setupWithI18n(() => useMissingModelInteractions())
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.resetAllMocks()
|
||||
mockGetAssetDisplayName.mockImplementation((a: { name: string }) => a.name)
|
||||
mockGetAssetFilename.mockImplementation((a: { name: string }) => a.name)
|
||||
mockDownloadList.mockImplementation(
|
||||
(): Array<{ taskId: string; status: string }> => []
|
||||
)
|
||||
;(app as { rootGraph: unknown }).rootGraph = null
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
for (const app of mountedApps.splice(0)) {
|
||||
app.unmount()
|
||||
}
|
||||
})
|
||||
|
||||
describe('getModelStateKey', () => {
|
||||
it('returns key with supported prefix when asset is supported', () => {
|
||||
expect(getModelStateKey('model.safetensors', 'checkpoints', true)).toBe(
|
||||
@@ -205,31 +184,101 @@ describe('useMissingModelInteractions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('getComboValue', () => {
|
||||
it('returns undefined when node is not found', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue(null)
|
||||
|
||||
const result = getComboValue(makeCandidate())
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when widget is not found', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [{ name: 'other_widget', value: 'test' }]
|
||||
})
|
||||
|
||||
const result = getComboValue(makeCandidate())
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns string value directly', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [{ name: 'ckpt_name', value: 'v1-5.safetensors' }]
|
||||
})
|
||||
|
||||
expect(getComboValue(makeCandidate())).toBe('v1-5.safetensors')
|
||||
})
|
||||
|
||||
it('returns stringified number value', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [{ name: 'ckpt_name', value: 42 }]
|
||||
})
|
||||
|
||||
expect(getComboValue(makeCandidate())).toBe('42')
|
||||
})
|
||||
|
||||
it('returns undefined for unexpected types', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [{ name: 'ckpt_name', value: { complex: true } }]
|
||||
})
|
||||
|
||||
expect(getComboValue(makeCandidate())).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when nodeId is null', () => {
|
||||
const result = getComboValue(makeCandidate({ nodeId: undefined }))
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('toggleModelExpand / isModelExpanded', () => {
|
||||
it('starts collapsed by default', () => {
|
||||
const { isModelExpanded } = setupMissingModelInteractions()
|
||||
const { isModelExpanded } = useMissingModelInteractions()
|
||||
expect(isModelExpanded('key1')).toBe(false)
|
||||
})
|
||||
|
||||
it('toggles to expanded', () => {
|
||||
const { toggleModelExpand, isModelExpanded } =
|
||||
setupMissingModelInteractions()
|
||||
useMissingModelInteractions()
|
||||
toggleModelExpand('key1')
|
||||
expect(isModelExpanded('key1')).toBe(true)
|
||||
})
|
||||
|
||||
it('toggles back to collapsed', () => {
|
||||
const { toggleModelExpand, isModelExpanded } =
|
||||
setupMissingModelInteractions()
|
||||
useMissingModelInteractions()
|
||||
toggleModelExpand('key1')
|
||||
toggleModelExpand('key1')
|
||||
expect(isModelExpanded('key1')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleComboSelect', () => {
|
||||
it('sets selectedLibraryModel in store', () => {
|
||||
const store = useMissingModelStore()
|
||||
const { handleComboSelect } = useMissingModelInteractions()
|
||||
|
||||
handleComboSelect('key1', 'model_v2.safetensors')
|
||||
expect(store.selectedLibraryModel['key1']).toBe('model_v2.safetensors')
|
||||
})
|
||||
|
||||
it('does not set value when undefined', () => {
|
||||
const store = useMissingModelStore()
|
||||
const { handleComboSelect } = useMissingModelInteractions()
|
||||
|
||||
handleComboSelect('key1', undefined)
|
||||
expect(store.selectedLibraryModel['key1']).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('isSelectionConfirmable', () => {
|
||||
it('returns false when no selection exists', () => {
|
||||
const { isSelectionConfirmable } = setupMissingModelInteractions()
|
||||
const { isSelectionConfirmable } = useMissingModelInteractions()
|
||||
expect(isSelectionConfirmable('key1')).toBe(false)
|
||||
})
|
||||
|
||||
@@ -241,7 +290,7 @@ describe('useMissingModelInteractions', () => {
|
||||
{ taskId: 'task-123', status: 'running' }
|
||||
])
|
||||
|
||||
const { isSelectionConfirmable } = setupMissingModelInteractions()
|
||||
const { isSelectionConfirmable } = useMissingModelInteractions()
|
||||
expect(isSelectionConfirmable('key1')).toBe(false)
|
||||
})
|
||||
|
||||
@@ -250,7 +299,7 @@ describe('useMissingModelInteractions', () => {
|
||||
store.selectedLibraryModel['key1'] = 'model.safetensors'
|
||||
store.importCategoryMismatch['key1'] = 'loras'
|
||||
|
||||
const { isSelectionConfirmable } = setupMissingModelInteractions()
|
||||
const { isSelectionConfirmable } = useMissingModelInteractions()
|
||||
expect(isSelectionConfirmable('key1')).toBe(false)
|
||||
})
|
||||
|
||||
@@ -259,7 +308,7 @@ describe('useMissingModelInteractions', () => {
|
||||
store.selectedLibraryModel['key1'] = 'model.safetensors'
|
||||
mockDownloadList.mockReturnValue([])
|
||||
|
||||
const { isSelectionConfirmable } = setupMissingModelInteractions()
|
||||
const { isSelectionConfirmable } = useMissingModelInteractions()
|
||||
expect(isSelectionConfirmable('key1')).toBe(true)
|
||||
})
|
||||
})
|
||||
@@ -269,14 +318,12 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
store.selectedLibraryModel['key1'] = 'model.safetensors'
|
||||
store.importCategoryMismatch['key1'] = 'loras'
|
||||
store.importTaskIds['key1'] = 'task-123'
|
||||
|
||||
const { cancelLibrarySelect } = setupMissingModelInteractions()
|
||||
const { cancelLibrarySelect } = useMissingModelInteractions()
|
||||
cancelLibrarySelect('key1')
|
||||
|
||||
expect(store.selectedLibraryModel['key1']).toBeUndefined()
|
||||
expect(store.importCategoryMismatch['key1']).toBeUndefined()
|
||||
expect(store.importTaskIds['key1']).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -300,7 +347,6 @@ describe('useMissingModelInteractions', () => {
|
||||
|
||||
const store = useMissingModelStore()
|
||||
store.selectedLibraryModel['key1'] = 'new_model.safetensors'
|
||||
store.importTaskIds['key1'] = 'task-123'
|
||||
store.setMissingModels([
|
||||
makeCandidate({ name: 'old_model.safetensors', nodeId: '10' }),
|
||||
makeCandidate({ name: 'old_model.safetensors', nodeId: '20' })
|
||||
@@ -308,7 +354,7 @@ describe('useMissingModelInteractions', () => {
|
||||
|
||||
const removeSpy = vi.spyOn(store, 'removeMissingModelByNameOnNodes')
|
||||
|
||||
const { confirmLibrarySelect } = setupMissingModelInteractions()
|
||||
const { confirmLibrarySelect } = useMissingModelInteractions()
|
||||
confirmLibrarySelect(
|
||||
'key1',
|
||||
'old_model.safetensors',
|
||||
@@ -326,7 +372,6 @@ describe('useMissingModelInteractions', () => {
|
||||
new Set(['10', '20'])
|
||||
)
|
||||
expect(store.selectedLibraryModel['key1']).toBeUndefined()
|
||||
expect(store.importTaskIds['key1']).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does nothing when no selection exists', () => {
|
||||
@@ -334,7 +379,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
const removeSpy = vi.spyOn(store, 'removeMissingModelByNameOnNodes')
|
||||
|
||||
const { confirmLibrarySelect } = setupMissingModelInteractions()
|
||||
const { confirmLibrarySelect } = useMissingModelInteractions()
|
||||
confirmLibrarySelect('key1', 'model.safetensors', [], null)
|
||||
|
||||
expect(removeSpy).not.toHaveBeenCalled()
|
||||
@@ -346,7 +391,7 @@ describe('useMissingModelInteractions', () => {
|
||||
store.selectedLibraryModel['key1'] = 'new.safetensors'
|
||||
const removeSpy = vi.spyOn(store, 'removeMissingModelByNameOnNodes')
|
||||
|
||||
const { confirmLibrarySelect } = setupMissingModelInteractions()
|
||||
const { confirmLibrarySelect } = useMissingModelInteractions()
|
||||
confirmLibrarySelect('key1', 'model.safetensors', [], null)
|
||||
|
||||
expect(removeSpy).not.toHaveBeenCalled()
|
||||
@@ -362,7 +407,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
store.selectedLibraryModel['key1'] = 'new.safetensors'
|
||||
|
||||
const { confirmLibrarySelect } = setupMissingModelInteractions()
|
||||
const { confirmLibrarySelect } = useMissingModelInteractions()
|
||||
confirmLibrarySelect('key1', 'model.safetensors', [], 'checkpoints')
|
||||
|
||||
expect(mockGetAllNodeProviders).toHaveBeenCalledWith('checkpoints')
|
||||
@@ -376,7 +421,7 @@ describe('useMissingModelInteractions', () => {
|
||||
store.urlErrors['key1'] = 'old error'
|
||||
store.urlFetching['key1'] = true
|
||||
|
||||
const { handleUrlInput } = setupMissingModelInteractions()
|
||||
const { handleUrlInput } = useMissingModelInteractions()
|
||||
handleUrlInput('key1', 'https://civitai.com/models/123')
|
||||
|
||||
expect(store.urlInputs['key1']).toBe('https://civitai.com/models/123')
|
||||
@@ -389,7 +434,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
const setTimerSpy = vi.spyOn(store, 'setDebounceTimer')
|
||||
|
||||
const { handleUrlInput } = setupMissingModelInteractions()
|
||||
const { handleUrlInput } = useMissingModelInteractions()
|
||||
handleUrlInput('key1', ' ')
|
||||
|
||||
expect(setTimerSpy).not.toHaveBeenCalled()
|
||||
@@ -399,7 +444,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
const setTimerSpy = vi.spyOn(store, 'setDebounceTimer')
|
||||
|
||||
const { handleUrlInput } = setupMissingModelInteractions()
|
||||
const { handleUrlInput } = useMissingModelInteractions()
|
||||
handleUrlInput('key1', 'https://civitai.com/models/123')
|
||||
|
||||
expect(setTimerSpy).toHaveBeenCalledWith(
|
||||
@@ -413,7 +458,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
const clearTimerSpy = vi.spyOn(store, 'clearDebounceTimer')
|
||||
|
||||
const { handleUrlInput } = setupMissingModelInteractions()
|
||||
const { handleUrlInput } = useMissingModelInteractions()
|
||||
handleUrlInput('key1', 'https://civitai.com/models/123')
|
||||
|
||||
expect(clearTimerSpy).toHaveBeenCalledWith('key1')
|
||||
@@ -422,12 +467,12 @@ describe('useMissingModelInteractions', () => {
|
||||
|
||||
describe('getTypeMismatch', () => {
|
||||
it('returns null when groupDirectory is null', () => {
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', null)).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when no metadata exists', () => {
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', 'checkpoints')).toBeNull()
|
||||
})
|
||||
|
||||
@@ -435,7 +480,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = useMissingModelStore()
|
||||
store.urlMetadata['key1'] = { name: 'model', tags: [] } as never
|
||||
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', 'checkpoints')).toBeNull()
|
||||
})
|
||||
|
||||
@@ -446,7 +491,7 @@ describe('useMissingModelInteractions', () => {
|
||||
tags: ['checkpoints']
|
||||
} as never
|
||||
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', 'checkpoints')).toBeNull()
|
||||
})
|
||||
|
||||
@@ -457,7 +502,7 @@ describe('useMissingModelInteractions', () => {
|
||||
tags: ['loras']
|
||||
} as never
|
||||
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', 'checkpoints')).toBe('loras')
|
||||
})
|
||||
|
||||
@@ -468,14 +513,63 @@ describe('useMissingModelInteractions', () => {
|
||||
tags: ['other', 'random']
|
||||
} as never
|
||||
|
||||
const { getTypeMismatch } = setupMissingModelInteractions()
|
||||
const { getTypeMismatch } = useMissingModelInteractions()
|
||||
expect(getTypeMismatch('key1', 'checkpoints')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getComboOptions', () => {
|
||||
it('returns assets from assetsStore when the model is asset-supported', () => {
|
||||
mockGetAssets.mockReturnValueOnce([
|
||||
{ name: 'modelA.safetensors' },
|
||||
{ name: 'modelB.safetensors' }
|
||||
])
|
||||
|
||||
const { getComboOptions } = useMissingModelInteractions()
|
||||
const options = getComboOptions(makeCandidate({ isAssetSupported: true }))
|
||||
|
||||
expect(mockGetAssets).toHaveBeenCalledWith('CheckpointLoaderSimple')
|
||||
expect(options).toEqual([
|
||||
{ name: 'modelA.safetensors', value: 'modelA.safetensors' },
|
||||
{ name: 'modelB.safetensors', value: 'modelB.safetensors' }
|
||||
])
|
||||
})
|
||||
|
||||
it('returns widget options when the model is not asset-supported', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [
|
||||
{
|
||||
name: 'ckpt_name',
|
||||
value: '',
|
||||
options: { values: ['v1.safetensors', 'v2.safetensors'] }
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const { getComboOptions } = useMissingModelInteractions()
|
||||
const options = getComboOptions(makeCandidate())
|
||||
|
||||
expect(options).toEqual([
|
||||
{ name: 'v1.safetensors', value: 'v1.safetensors' },
|
||||
{ name: 'v2.safetensors', value: 'v2.safetensors' }
|
||||
])
|
||||
})
|
||||
|
||||
it('returns an empty array when the widget has no options.values', () => {
|
||||
;(app as { rootGraph: unknown }).rootGraph = {}
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
widgets: [{ name: 'ckpt_name', value: '' }]
|
||||
})
|
||||
|
||||
const { getComboOptions } = useMissingModelInteractions()
|
||||
expect(getComboOptions(makeCandidate())).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDownloadStatus', () => {
|
||||
it('returns null when no taskId is tracked for the key', () => {
|
||||
const { getDownloadStatus } = setupMissingModelInteractions()
|
||||
const { getDownloadStatus } = useMissingModelInteractions()
|
||||
expect(getDownloadStatus('key1')).toBeNull()
|
||||
})
|
||||
|
||||
@@ -487,7 +581,7 @@ describe('useMissingModelInteractions', () => {
|
||||
{ taskId: 'task-42', status: 'created' }
|
||||
])
|
||||
|
||||
const { getDownloadStatus } = setupMissingModelInteractions()
|
||||
const { getDownloadStatus } = useMissingModelInteractions()
|
||||
expect(getDownloadStatus('key1')).toEqual({
|
||||
taskId: 'task-42',
|
||||
status: 'created'
|
||||
@@ -514,7 +608,7 @@ describe('useMissingModelInteractions', () => {
|
||||
task: { task_id: 'task-99', status: 'created' }
|
||||
})
|
||||
|
||||
const { handleImport } = setupMissingModelInteractions()
|
||||
const { handleImport } = useMissingModelInteractions()
|
||||
await handleImport('key1', 'checkpoints')
|
||||
|
||||
expect(store.importTaskIds['key1']).toBe('task-99')
|
||||
@@ -532,7 +626,7 @@ describe('useMissingModelInteractions', () => {
|
||||
task: { task_id: 'task-100', status: 'completed' }
|
||||
})
|
||||
|
||||
const { handleImport } = setupMissingModelInteractions()
|
||||
const { handleImport } = useMissingModelInteractions()
|
||||
await handleImport('key1', 'checkpoints')
|
||||
|
||||
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith(
|
||||
@@ -547,7 +641,7 @@ describe('useMissingModelInteractions', () => {
|
||||
asset: { tags: ['models', 'loras'] }
|
||||
})
|
||||
|
||||
const { handleImport } = setupMissingModelInteractions()
|
||||
const { handleImport } = useMissingModelInteractions()
|
||||
await handleImport('key1', 'checkpoints')
|
||||
|
||||
expect(store.importCategoryMismatch['key1']).toBe('loras')
|
||||
@@ -557,7 +651,7 @@ describe('useMissingModelInteractions', () => {
|
||||
const store = setupImportableState('key1')
|
||||
mockUploadAssetAsync.mockRejectedValueOnce(new Error('Upload boom'))
|
||||
|
||||
const { handleImport } = setupMissingModelInteractions()
|
||||
const { handleImport } = useMissingModelInteractions()
|
||||
await handleImport('key1', 'checkpoints')
|
||||
|
||||
expect(store.urlErrors['key1']).toBe('Upload boom')
|
||||
|
||||
@@ -3,9 +3,12 @@ import { useI18n } from 'vue-i18n'
|
||||
import { resolveNodeDisplayName } from '@/utils/nodeTitleUtil'
|
||||
import { st } from '@/i18n'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import {
|
||||
getAssetDisplayName,
|
||||
getAssetFilename
|
||||
} from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { civitaiImportSource } from '@/platform/assets/importSources/civitaiImportSource'
|
||||
import { huggingfaceImportSource } from '@/platform/assets/importSources/huggingfaceImportSource'
|
||||
import type { UploadModelSuccess } from '@/platform/assets/composables/useUploadModelWizard'
|
||||
import { validateSourceUrl } from '@/platform/assets/utils/importSourceUtil'
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
@@ -13,7 +16,12 @@ import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { getNodeByExecutionId } from '@/utils/graphTraversalUtil'
|
||||
import type { MissingModelViewModel } from '@/platform/missingModel/types'
|
||||
import type {
|
||||
MissingModelCandidate,
|
||||
MissingModelViewModel
|
||||
} from '@/platform/missingModel/types'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
const importSources = [civitaiImportSource, huggingfaceImportSource]
|
||||
|
||||
@@ -50,6 +58,33 @@ export function getNodeDisplayLabel(
|
||||
})
|
||||
}
|
||||
|
||||
function getModelComboWidget(
|
||||
model: MissingModelCandidate
|
||||
): { node: LGraphNode; widget: IBaseWidget } | null {
|
||||
if (model.nodeId == null) return null
|
||||
|
||||
const graph = app.rootGraph
|
||||
if (!graph) return null
|
||||
const node = getNodeByExecutionId(graph, String(model.nodeId))
|
||||
if (!node) return null
|
||||
|
||||
const widget = node.widgets?.find((w) => w.name === model.widgetName)
|
||||
if (!widget) return null
|
||||
|
||||
return { node, widget }
|
||||
}
|
||||
|
||||
export function getComboValue(
|
||||
model: MissingModelCandidate
|
||||
): string | undefined {
|
||||
const result = getModelComboWidget(model)
|
||||
if (!result) return undefined
|
||||
const val = result.widget.value
|
||||
if (typeof val === 'string') return val
|
||||
if (typeof val === 'number') return String(val)
|
||||
return undefined
|
||||
}
|
||||
|
||||
export function useMissingModelInteractions() {
|
||||
const { t } = useI18n()
|
||||
const store = useMissingModelStore()
|
||||
@@ -67,6 +102,30 @@ export function useMissingModelInteractions() {
|
||||
return store.modelExpandState[key] ?? false
|
||||
}
|
||||
|
||||
function getComboOptions(
|
||||
model: MissingModelCandidate
|
||||
): { name: string; value: string }[] {
|
||||
if (model.isAssetSupported && model.nodeType) {
|
||||
const assets = assetsStore.getAssets(model.nodeType) ?? []
|
||||
return assets.map((asset) => ({
|
||||
name: getAssetDisplayName(asset),
|
||||
value: getAssetFilename(asset)
|
||||
}))
|
||||
}
|
||||
|
||||
const result = getModelComboWidget(model)
|
||||
if (!result) return []
|
||||
const values = result.widget.options?.values
|
||||
if (!Array.isArray(values)) return []
|
||||
return values.map((v) => ({ name: String(v), value: String(v) }))
|
||||
}
|
||||
|
||||
function handleComboSelect(key: string, value: string | undefined) {
|
||||
if (value) {
|
||||
store.selectedLibraryModel[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
function isSelectionConfirmable(key: string): boolean {
|
||||
if (!store.selectedLibraryModel[key]) return false
|
||||
if (store.importCategoryMismatch[key]) return false
|
||||
@@ -84,7 +143,6 @@ export function useMissingModelInteractions() {
|
||||
function cancelLibrarySelect(key: string) {
|
||||
delete store.selectedLibraryModel[key]
|
||||
delete store.importCategoryMismatch[key]
|
||||
delete store.importTaskIds[key]
|
||||
}
|
||||
|
||||
/** Apply selected model to referencing nodes, removing only that model from the error list. */
|
||||
@@ -131,7 +189,6 @@ export function useMissingModelInteractions() {
|
||||
}
|
||||
|
||||
delete store.selectedLibraryModel[key]
|
||||
delete store.importTaskIds[key]
|
||||
const nodeIdSet = new Set(referencingNodes.map((ref) => String(ref.nodeId)))
|
||||
store.removeMissingModelByNameOnNodes(modelName, nodeIdSet)
|
||||
}
|
||||
@@ -250,16 +307,6 @@ export function useMissingModelInteractions() {
|
||||
}
|
||||
}
|
||||
|
||||
function handleUploadedModelImport(key: string, result: UploadModelSuccess) {
|
||||
if (result.taskId) {
|
||||
handleAsyncPending(key, result.taskId, result.modelType, result.filename)
|
||||
} else if (result.status === 'success') {
|
||||
handleAsyncCompleted(result.modelType)
|
||||
}
|
||||
|
||||
store.selectedLibraryModel[key] = result.filename
|
||||
}
|
||||
|
||||
function handleSyncResult(
|
||||
key: string,
|
||||
tags: string[],
|
||||
@@ -333,13 +380,14 @@ export function useMissingModelInteractions() {
|
||||
return {
|
||||
toggleModelExpand,
|
||||
isModelExpanded,
|
||||
getComboOptions,
|
||||
handleComboSelect,
|
||||
isSelectionConfirmable,
|
||||
cancelLibrarySelect,
|
||||
confirmLibrarySelect,
|
||||
handleUrlInput,
|
||||
getTypeMismatch,
|
||||
getDownloadStatus,
|
||||
handleUploadedModelImport,
|
||||
handleImport
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { render, screen, within } from '@testing-library/vue'
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
@@ -15,11 +15,8 @@ const i18n = createI18n({
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
g: {
|
||||
nodesCount: '{count} node | {count} nodes'
|
||||
},
|
||||
rightSidePanel: {
|
||||
locateNode: 'Locate node on canvas',
|
||||
locateNode: 'Locate Node',
|
||||
missingNodePacks: {
|
||||
collapse: 'Collapse',
|
||||
expand: 'Expand'
|
||||
@@ -51,6 +48,7 @@ function makeGroup(overrides: Partial<SwapNodeGroup> = {}): SwapNodeGroup {
|
||||
function renderRow(
|
||||
props: Partial<{
|
||||
group: SwapNodeGroup
|
||||
showNodeIdBadge: boolean
|
||||
'onLocate-node': (nodeId: string) => void
|
||||
onReplace: (group: SwapNodeGroup) => void
|
||||
}> = {}
|
||||
@@ -58,6 +56,7 @@ function renderRow(
|
||||
return render(SwapNodeGroupRow, {
|
||||
props: {
|
||||
group: makeGroup(),
|
||||
showNodeIdBadge: false,
|
||||
...props
|
||||
},
|
||||
global: {
|
||||
@@ -76,15 +75,13 @@ describe('SwapNodeGroupRow', () => {
|
||||
expect(container.textContent).toContain('OldNodeType')
|
||||
})
|
||||
|
||||
it('renders node count as a badge', () => {
|
||||
renderRow()
|
||||
const badge = screen.getByLabelText('2 nodes')
|
||||
expect(badge).toBeInTheDocument()
|
||||
expect(within(badge).getByText('2')).toBeInTheDocument()
|
||||
it('renders node count in parentheses', () => {
|
||||
const { container } = renderRow()
|
||||
expect(container.textContent).toContain('(2)')
|
||||
})
|
||||
|
||||
it('renders node count of 5 for 5 nodeTypes', () => {
|
||||
renderRow({
|
||||
const { container } = renderRow({
|
||||
group: makeGroup({
|
||||
nodeTypes: Array.from({ length: 5 }, (_, i) => ({
|
||||
type: 'OldNodeType',
|
||||
@@ -93,9 +90,7 @@ describe('SwapNodeGroupRow', () => {
|
||||
}))
|
||||
})
|
||||
})
|
||||
const badge = screen.getByLabelText('5 nodes')
|
||||
expect(badge).toBeInTheDocument()
|
||||
expect(within(badge).getByText('5')).toBeInTheDocument()
|
||||
expect(container.textContent).toContain('(5)')
|
||||
})
|
||||
|
||||
it('renders the replacement target name', () => {
|
||||
@@ -120,147 +115,106 @@ describe('SwapNodeGroupRow', () => {
|
||||
|
||||
describe('Expand / Collapse', () => {
|
||||
it('starts collapsed — node list not visible', () => {
|
||||
renderRow()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
).not.toBeInTheDocument()
|
||||
const { container } = renderRow({ showNodeIdBadge: true })
|
||||
expect(container.textContent).not.toContain('#1')
|
||||
})
|
||||
|
||||
it('expands when title is clicked', async () => {
|
||||
it('expands when chevron is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderRow()
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Expand OldNodeType' })
|
||||
)
|
||||
expect(
|
||||
screen.getAllByRole('button', { name: 'Locate node on canvas' })
|
||||
).toHaveLength(2)
|
||||
const { container } = renderRow({ showNodeIdBadge: true })
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
expect(container.textContent).toContain('#1')
|
||||
expect(container.textContent).toContain('#2')
|
||||
})
|
||||
|
||||
it('collapses when title is clicked again', async () => {
|
||||
it('collapses when chevron is clicked again', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderRow()
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Expand OldNodeType' })
|
||||
)
|
||||
expect(
|
||||
screen.getAllByRole('button', { name: 'Locate node on canvas' })
|
||||
).toHaveLength(2)
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Collapse OldNodeType' })
|
||||
)
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
).not.toBeInTheDocument()
|
||||
const { container } = renderRow({ showNodeIdBadge: true })
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
expect(container.textContent).toContain('#1')
|
||||
await user.click(screen.getByRole('button', { name: 'Collapse' }))
|
||||
expect(container.textContent).not.toContain('#1')
|
||||
})
|
||||
|
||||
it('updates the toggle control state when expanded', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderRow()
|
||||
const titleButton = screen.getByRole('button', {
|
||||
name: 'Expand OldNodeType'
|
||||
})
|
||||
expect(titleButton).toHaveAttribute('aria-expanded', 'false')
|
||||
|
||||
await user.click(titleButton)
|
||||
|
||||
const collapseButton = screen.getByRole('button', {
|
||||
name: 'Collapse OldNodeType'
|
||||
})
|
||||
expect(collapseButton).toHaveAttribute('aria-expanded', 'true')
|
||||
expect(screen.getByRole('button', { name: 'Expand' })).toBeInTheDocument()
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
expect(
|
||||
screen.getByRole('button', { name: 'Collapse' })
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Node Type List (Expanded)', () => {
|
||||
async function expand() {
|
||||
const user = userEvent.setup()
|
||||
await user.click(screen.getByRole('button', { name: /^Expand / }))
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
}
|
||||
|
||||
it('renders all nodeTypes when expanded', async () => {
|
||||
renderRow({
|
||||
const { container } = renderRow({
|
||||
group: makeGroup({
|
||||
type: 'GroupedNodeType',
|
||||
nodeTypes: [
|
||||
{ type: 'GroupedNodeType', nodeId: '10', isReplaceable: true },
|
||||
{ type: 'GroupedNodeType', nodeId: '20', isReplaceable: true },
|
||||
{ type: 'GroupedNodeType', nodeId: '30', isReplaceable: true }
|
||||
{ type: 'OldNodeType', nodeId: '10', isReplaceable: true },
|
||||
{ type: 'OldNodeType', nodeId: '20', isReplaceable: true },
|
||||
{ type: 'OldNodeType', nodeId: '30', isReplaceable: true }
|
||||
]
|
||||
})
|
||||
}),
|
||||
showNodeIdBadge: true
|
||||
})
|
||||
expect(screen.queryByRole('list')).not.toBeInTheDocument()
|
||||
|
||||
await expand()
|
||||
expect(container.textContent).toContain('#10')
|
||||
expect(container.textContent).toContain('#20')
|
||||
expect(container.textContent).toContain('#30')
|
||||
})
|
||||
|
||||
expect(
|
||||
within(screen.getByRole('list')).getAllByRole('listitem')
|
||||
).toHaveLength(3)
|
||||
expect(
|
||||
within(screen.getByRole('list')).getAllByText('GroupedNodeType')
|
||||
).toHaveLength(3)
|
||||
it('shows nodeId badge when showNodeIdBadge is true', async () => {
|
||||
const { container } = renderRow({ showNodeIdBadge: true })
|
||||
await expand()
|
||||
expect(container.textContent).toContain('#1')
|
||||
expect(container.textContent).toContain('#2')
|
||||
})
|
||||
|
||||
it('hides nodeId badge when showNodeIdBadge is false', async () => {
|
||||
const { container } = renderRow({ showNodeIdBadge: false })
|
||||
await expand()
|
||||
expect(container.textContent).not.toContain('#1')
|
||||
expect(container.textContent).not.toContain('#2')
|
||||
})
|
||||
|
||||
it('renders Locate button for each nodeType with nodeId', async () => {
|
||||
renderRow()
|
||||
renderRow({ showNodeIdBadge: true })
|
||||
await expand()
|
||||
expect(
|
||||
screen.getAllByRole('button', { name: 'Locate node on canvas' })
|
||||
screen.getAllByRole('button', { name: 'Locate Node' })
|
||||
).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('does not render Locate button for nodeTypes without nodeId', async () => {
|
||||
renderRow({
|
||||
group: makeGroup({
|
||||
// Intentionally omits nodeId to test graceful handling of incomplete node data
|
||||
nodeTypes: fromAny<MissingNodeType[], unknown>([
|
||||
{ type: 'NoIdNode', isReplaceable: true },
|
||||
{ type: 'OtherNoIdNode', isReplaceable: true }
|
||||
{ type: 'NoIdNode', isReplaceable: true }
|
||||
])
|
||||
})
|
||||
})
|
||||
await expand()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
screen.queryByRole('button', { name: 'Locate Node' })
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders locate controls only for locatable nodeTypes', async () => {
|
||||
renderRow({
|
||||
group: makeGroup({
|
||||
type: 'MixedNodeType',
|
||||
nodeTypes: fromAny<MissingNodeType[], unknown>([
|
||||
{ type: 'MixedNodeType', nodeId: '10', isReplaceable: true },
|
||||
{ type: 'MixedNodeType', isReplaceable: true }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
await expand()
|
||||
|
||||
expect(
|
||||
within(screen.getByRole('list')).getAllByText('MixedNodeType')
|
||||
).toHaveLength(2)
|
||||
expect(
|
||||
within(screen.getByRole('list')).getAllByRole('button', {
|
||||
name: 'MixedNodeType'
|
||||
})
|
||||
).toHaveLength(1)
|
||||
expect(
|
||||
screen.getAllByRole('button', { name: 'Locate node on canvas' })
|
||||
).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Events', () => {
|
||||
it('emits locate-node with correct nodeId', async () => {
|
||||
const onLocateNode = vi.fn()
|
||||
const user = userEvent.setup()
|
||||
renderRow({ 'onLocate-node': onLocateNode })
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Expand OldNodeType' })
|
||||
)
|
||||
const locateBtns = screen.getAllByRole('button', {
|
||||
name: 'Locate node on canvas'
|
||||
})
|
||||
renderRow({ showNodeIdBadge: true, 'onLocate-node': onLocateNode })
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
const locateBtns = screen.getAllByRole('button', { name: 'Locate Node' })
|
||||
await user.click(locateBtns[0])
|
||||
expect(onLocateNode).toHaveBeenCalledWith('1')
|
||||
|
||||
@@ -279,100 +233,24 @@ describe('SwapNodeGroupRow', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Single Node Groups', () => {
|
||||
it('locates a single node without expanding', async () => {
|
||||
const onLocateNode = vi.fn()
|
||||
const user = userEvent.setup()
|
||||
renderRow({
|
||||
group: makeGroup({
|
||||
type: 'SingleNodeType',
|
||||
nodeTypes: [
|
||||
{ type: 'SingleNodeType', nodeId: '42', isReplaceable: true }
|
||||
]
|
||||
}),
|
||||
'onLocate-node': onLocateNode
|
||||
})
|
||||
|
||||
expect(
|
||||
screen.queryByRole('button', { name: /^Expand / })
|
||||
).not.toBeInTheDocument()
|
||||
expect(screen.queryByLabelText('1 node')).not.toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'SingleNodeType' }))
|
||||
expect(onLocateNode).toHaveBeenCalledWith('42')
|
||||
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Locate node on canvas' })
|
||||
)
|
||||
expect(onLocateNode).toHaveBeenCalledTimes(2)
|
||||
expect(onLocateNode).toHaveBeenLastCalledWith('42')
|
||||
})
|
||||
|
||||
it('renders a single node without nodeId as non-locatable text', () => {
|
||||
renderRow({
|
||||
group: makeGroup({
|
||||
type: 'NoIdNode',
|
||||
nodeTypes: fromAny<MissingNodeType[], unknown>([
|
||||
{ type: 'NoIdNode', isReplaceable: true }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
expect(screen.getByText('NoIdNode')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'NoIdNode' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('handles empty nodeTypes array', () => {
|
||||
renderRow({
|
||||
group: makeGroup({
|
||||
nodeTypes: []
|
||||
})
|
||||
const { container } = renderRow({
|
||||
group: makeGroup({ nodeTypes: [] })
|
||||
})
|
||||
|
||||
expect(screen.getByText('OldNodeType')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'OldNodeType' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: /^Expand / })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(container.textContent).toContain('(0)')
|
||||
})
|
||||
|
||||
it('handles string nodeType entries', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderRow({
|
||||
const { container } = renderRow({
|
||||
group: makeGroup({
|
||||
nodeTypes: fromAny<MissingNodeType[], unknown>([
|
||||
'StringType',
|
||||
'OtherStringType'
|
||||
])
|
||||
// Intentionally uses a plain string entry to test legacy node type handling
|
||||
nodeTypes: fromAny<MissingNodeType[], unknown>(['StringType'])
|
||||
})
|
||||
})
|
||||
await user.click(
|
||||
screen.getByRole('button', { name: 'Expand OldNodeType' })
|
||||
)
|
||||
|
||||
expect(screen.getByText('StringType')).toBeInTheDocument()
|
||||
expect(screen.getByText('OtherStringType')).toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'StringType' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'OtherStringType' })
|
||||
).not.toBeInTheDocument()
|
||||
expect(
|
||||
screen.queryByRole('button', { name: 'Locate node on canvas' })
|
||||
).not.toBeInTheDocument()
|
||||
await user.click(screen.getByRole('button', { name: 'Expand' }))
|
||||
expect(container.textContent).toContain('StringType')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,152 +1,100 @@
|
||||
<template>
|
||||
<div class="mb-1 flex w-full flex-col gap-0.5 last:mb-0">
|
||||
<div class="flex min-h-8 w-full items-center gap-1">
|
||||
<div class="mb-4 flex w-full flex-col">
|
||||
<!-- Type header row: type name + chevron -->
|
||||
<div class="flex h-8 w-full items-center">
|
||||
<p class="text-foreground min-w-0 flex-1 truncate text-sm font-medium">
|
||||
{{ `${group.type} (${group.nodeTypes.length})` }}
|
||||
</p>
|
||||
|
||||
<Button
|
||||
v-if="hasMultipleNodeTypes"
|
||||
variant="textonly"
|
||||
size="unset"
|
||||
size="icon-sm"
|
||||
:class="
|
||||
cn(
|
||||
'h-8 w-4 shrink-0 p-0 transition-transform duration-200 hover:bg-transparent',
|
||||
expanded && 'rotate-90'
|
||||
'size-8 shrink-0 transition-transform duration-200 hover:bg-transparent',
|
||||
{ 'rotate-180': expanded }
|
||||
)
|
||||
"
|
||||
aria-hidden="true"
|
||||
tabindex="-1"
|
||||
:aria-label="
|
||||
expanded
|
||||
? t('rightSidePanel.missingNodePacks.collapse', 'Collapse')
|
||||
: t('rightSidePanel.missingNodePacks.expand', 'Expand')
|
||||
"
|
||||
@click="toggleExpand"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="icon-[lucide--chevron-right] size-4 text-muted-foreground"
|
||||
class="icon-[lucide--chevron-down] size-4 text-muted-foreground group-hover:text-base-foreground"
|
||||
/>
|
||||
</Button>
|
||||
|
||||
<span class="flex min-w-0 flex-1 flex-col gap-0">
|
||||
<span class="flex min-w-0 items-center gap-2">
|
||||
<span class="flex min-w-0 items-center gap-2.5">
|
||||
<button
|
||||
v-if="hasMultipleNodeTypes"
|
||||
type="button"
|
||||
class="focus-visible:ring-ring m-0 inline max-w-full cursor-pointer appearance-none rounded-sm border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-base-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-1 focus-visible:outline-none"
|
||||
:title="group.type"
|
||||
:aria-label="titleToggleAriaLabel"
|
||||
:aria-expanded="expanded"
|
||||
@click="toggleExpand"
|
||||
>
|
||||
{{ group.type }}
|
||||
</button>
|
||||
<button
|
||||
v-else-if="primaryLocatableNodeType"
|
||||
type="button"
|
||||
class="focus-visible:ring-ring m-0 inline max-w-full cursor-pointer appearance-none rounded-sm border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-base-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-1 focus-visible:outline-none"
|
||||
:title="group.type"
|
||||
@click="handleLocateNode(primaryLocatableNodeType)"
|
||||
>
|
||||
{{ group.type }}
|
||||
</button>
|
||||
<span
|
||||
v-else
|
||||
class="min-w-0 truncate text-sm/relaxed font-normal text-base-foreground"
|
||||
:title="group.type"
|
||||
>
|
||||
{{ group.type }}
|
||||
</span>
|
||||
<span
|
||||
v-if="hasMultipleNodeTypes"
|
||||
data-testid="swap-node-group-count"
|
||||
role="img"
|
||||
class="flex size-6 shrink-0 items-center justify-center rounded-md bg-secondary-background-selected text-xs font-bold text-muted-foreground"
|
||||
:aria-label="t('g.nodesCount', group.nodeTypes.length)"
|
||||
>
|
||||
{{ group.nodeTypes.length }}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
<span class="min-w-0 text-xs/relaxed text-muted-foreground">
|
||||
{{
|
||||
t(
|
||||
'nodeReplacement.willBeReplacedBy',
|
||||
'This node will be replaced by:'
|
||||
)
|
||||
}}
|
||||
<span
|
||||
class="inline-flex rounded-sm bg-modal-card-tag-background px-1.5 py-0.5 text-xs/none font-medium text-modal-card-tag-foreground"
|
||||
>
|
||||
{{ replacementLabel }}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
class="h-8 shrink-0 rounded-lg text-sm"
|
||||
@click="handleReplaceNode"
|
||||
>
|
||||
<i
|
||||
aria-hidden="true"
|
||||
class="text-foreground mr-1 icon-[lucide--repeat] size-4 shrink-0"
|
||||
/>
|
||||
<span class="text-foreground min-w-0 truncate">
|
||||
{{ t('nodeReplacement.replaceNode', 'Replace Node') }}
|
||||
</span>
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
v-if="primaryLocatableNodeType"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
class="size-8 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
:aria-label="locateNodeLabel"
|
||||
@click="handleLocateNode(primaryLocatableNodeType)"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--locate] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<!-- Sub-labels: individual node instances, each with their own Locate button -->
|
||||
<TransitionCollapse>
|
||||
<ul v-if="expanded" class="m-0 list-none space-y-1 p-0 pl-5">
|
||||
<li
|
||||
v-for="(nodeType, index) in group.nodeTypes"
|
||||
:key="getKey(nodeType, index)"
|
||||
class="min-w-0"
|
||||
<div
|
||||
v-if="expanded"
|
||||
class="mb-2 flex flex-col gap-0.5 overflow-hidden pl-2"
|
||||
>
|
||||
<div
|
||||
v-for="nodeType in group.nodeTypes"
|
||||
:key="getKey(nodeType)"
|
||||
class="flex h-7 items-center"
|
||||
>
|
||||
<div class="flex min-w-0 items-center gap-2">
|
||||
<span class="flex min-w-0 flex-1 items-center gap-1">
|
||||
<button
|
||||
v-if="isLocatableNodeType(nodeType)"
|
||||
type="button"
|
||||
class="focus-visible:ring-ring m-0 inline max-w-full cursor-pointer appearance-none rounded-sm border-0 bg-transparent p-0 text-left text-sm/relaxed font-normal wrap-break-word text-muted-foreground outline-none hover:text-base-foreground focus:outline-none focus-visible:underline focus-visible:ring-1 focus-visible:outline-none"
|
||||
@click="handleLocateNode(nodeType)"
|
||||
>
|
||||
{{ getLabel(nodeType) }}
|
||||
</button>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm/relaxed wrap-break-word text-muted-foreground"
|
||||
>
|
||||
{{ getLabel(nodeType) }}
|
||||
</span>
|
||||
</span>
|
||||
<Button
|
||||
v-if="isLocatableNodeType(nodeType)"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
class="size-8 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
:aria-label="locateNodeLabel"
|
||||
@click="handleLocateNode(nodeType)"
|
||||
>
|
||||
<i aria-hidden="true" class="icon-[lucide--locate] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
<span
|
||||
v-if="
|
||||
showNodeIdBadge &&
|
||||
typeof nodeType !== 'string' &&
|
||||
nodeType.nodeId != null
|
||||
"
|
||||
class="mr-1 shrink-0 rounded-md bg-secondary-background-selected px-2 py-0.5 font-mono text-xs font-bold text-muted-foreground"
|
||||
>
|
||||
#{{ nodeType.nodeId }}
|
||||
</span>
|
||||
<p class="min-w-0 flex-1 truncate text-xs text-muted-foreground">
|
||||
{{ getLabel(nodeType) }}
|
||||
</p>
|
||||
<Button
|
||||
v-if="typeof nodeType !== 'string' && nodeType.nodeId != null"
|
||||
variant="textonly"
|
||||
size="icon-sm"
|
||||
class="mr-1 size-6 shrink-0 text-muted-foreground hover:text-base-foreground"
|
||||
:aria-label="t('rightSidePanel.locateNode', 'Locate Node')"
|
||||
@click="handleLocateNode(nodeType)"
|
||||
>
|
||||
<i class="icon-[lucide--locate] size-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</TransitionCollapse>
|
||||
|
||||
<!-- Description rows: what it is replaced by -->
|
||||
<div class="mt-1 mb-2 flex flex-col gap-0.5 px-1 text-[13px]">
|
||||
<span class="text-muted-foreground">{{
|
||||
t('nodeReplacement.willBeReplacedBy', 'This node will be replaced by:')
|
||||
}}</span>
|
||||
<span class="text-foreground font-bold">{{
|
||||
group.newNodeId ?? t('nodeReplacement.unknownNode', 'Unknown')
|
||||
}}</span>
|
||||
</div>
|
||||
|
||||
<!-- Replace Action Button -->
|
||||
<div class="flex w-full items-start py-1">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="md"
|
||||
class="flex w-full flex-1"
|
||||
@click="handleReplaceNode"
|
||||
>
|
||||
<i class="text-foreground mr-1 icon-[lucide--repeat] size-4 shrink-0" />
|
||||
<span class="text-foreground min-w-0 truncate text-sm">
|
||||
{{ t('nodeReplacement.replaceNode', 'Replace Node') }}
|
||||
</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, ref } from 'vue'
|
||||
import { ref } from 'vue'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
@@ -154,8 +102,9 @@ import TransitionCollapse from '@/components/rightSidePanel/layout/TransitionCol
|
||||
import type { MissingNodeType } from '@/types/comfy'
|
||||
import type { SwapNodeGroup } from '@/components/rightSidePanel/errors/useErrorGroups'
|
||||
|
||||
const { group } = defineProps<{
|
||||
const props = defineProps<{
|
||||
group: SwapNodeGroup
|
||||
showNodeIdBadge: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -166,54 +115,28 @@ const emit = defineEmits<{
|
||||
const { t } = useI18n()
|
||||
|
||||
const expanded = ref(false)
|
||||
const hasMultipleNodeTypes = computed(() => group.nodeTypes.length > 1)
|
||||
const replacementLabel = computed(
|
||||
() => group.newNodeId ?? t('nodeReplacement.unknownNode', 'Unknown')
|
||||
)
|
||||
const locateNodeLabel = computed(() =>
|
||||
t('rightSidePanel.locateNode', 'Locate node on canvas')
|
||||
)
|
||||
const titleToggleAriaLabel = computed(
|
||||
() =>
|
||||
`${
|
||||
expanded.value
|
||||
? t('rightSidePanel.missingNodePacks.collapse', 'Collapse')
|
||||
: t('rightSidePanel.missingNodePacks.expand', 'Expand')
|
||||
} ${group.type}`
|
||||
)
|
||||
const primaryLocatableNodeType = computed(() => {
|
||||
if (group.nodeTypes.length !== 1) return null
|
||||
const [nodeType] = group.nodeTypes
|
||||
return isLocatableNodeType(nodeType) ? nodeType : null
|
||||
})
|
||||
|
||||
function toggleExpand() {
|
||||
expanded.value = !expanded.value
|
||||
}
|
||||
|
||||
function getKey(nodeType: MissingNodeType, index: number): string {
|
||||
if (typeof nodeType === 'string') return `${nodeType}-${index}`
|
||||
return nodeType.nodeId != null
|
||||
? String(nodeType.nodeId)
|
||||
: `${nodeType.type}-${index}`
|
||||
function getKey(nodeType: MissingNodeType): string {
|
||||
if (typeof nodeType === 'string') return nodeType
|
||||
return nodeType.nodeId != null ? String(nodeType.nodeId) : nodeType.type
|
||||
}
|
||||
|
||||
function getLabel(nodeType: MissingNodeType): string {
|
||||
return typeof nodeType === 'string' ? nodeType : nodeType.type
|
||||
}
|
||||
|
||||
function isLocatableNodeType(
|
||||
nodeType: MissingNodeType
|
||||
): nodeType is Exclude<MissingNodeType, string> & { nodeId: string | number } {
|
||||
return typeof nodeType !== 'string' && nodeType.nodeId != null
|
||||
}
|
||||
|
||||
function handleLocateNode(nodeType: MissingNodeType) {
|
||||
if (!isLocatableNodeType(nodeType)) return
|
||||
emit('locate-node', String(nodeType.nodeId))
|
||||
if (typeof nodeType === 'string') return
|
||||
if (nodeType.nodeId != null) {
|
||||
emit('locate-node', String(nodeType.nodeId))
|
||||
}
|
||||
}
|
||||
|
||||
function handleReplaceNode() {
|
||||
emit('replace', group)
|
||||
emit('replace', props.group)
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -10,8 +10,8 @@ vi.mock('./SwapNodeGroupRow.vue', () => ({
|
||||
default: {
|
||||
name: 'SwapNodeGroupRow',
|
||||
template:
|
||||
'<div class="swap-row" :data-group-type="group?.type"><button class="locate-trigger" @click="$emit(\'locate-node\', group?.nodeTypes?.[0]?.nodeId)">Locate</button><button class="replace-trigger" @click="$emit(\'replace\', group)">Replace</button></div>',
|
||||
props: ['group'],
|
||||
'<div class="swap-row" :data-show-node-id-badge="showNodeIdBadge" :data-group-type="group?.type"><button class="locate-trigger" @click="$emit(\'locate-node\', group?.nodeTypes?.[0]?.nodeId)">Locate</button><button class="replace-trigger" @click="$emit(\'replace\', group)">Replace</button></div>',
|
||||
props: ['group', 'showNodeIdBadge'],
|
||||
emits: ['locate-node', 'replace']
|
||||
}
|
||||
}))
|
||||
@@ -29,6 +29,7 @@ function makeGroups(count = 2): SwapNodeGroup[] {
|
||||
function mountCard(
|
||||
props: Partial<{
|
||||
swapNodeGroups: SwapNodeGroup[]
|
||||
showNodeIdBadge: boolean
|
||||
}> = {},
|
||||
callbacks?: {
|
||||
onLocateNode?: (nodeId: string) => void
|
||||
@@ -38,6 +39,7 @@ function mountCard(
|
||||
return render(SwapNodesCard, {
|
||||
props: {
|
||||
swapNodeGroups: makeGroups(),
|
||||
showNodeIdBadge: false,
|
||||
...props,
|
||||
...(callbacks?.onLocateNode
|
||||
? { 'onLocate-node': callbacks.onLocateNode }
|
||||
@@ -70,6 +72,16 @@ describe('SwapNodesCard', () => {
|
||||
expect(container.querySelectorAll('.swap-row')).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('passes showNodeIdBadge to children', () => {
|
||||
const { container } = mountCard({
|
||||
swapNodeGroups: makeGroups(1),
|
||||
showNodeIdBadge: true
|
||||
})
|
||||
// eslint-disable-next-line testing-library/no-container, testing-library/no-node-access
|
||||
const row = container.querySelector('.swap-row')
|
||||
expect(row!.getAttribute('data-show-node-id-badge')).toBe('true')
|
||||
})
|
||||
|
||||
it('passes group prop to children', () => {
|
||||
const groups = makeGroups(1)
|
||||
const { container } = mountCard({ swapNodeGroups: groups })
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
v-for="group in swapNodeGroups"
|
||||
:key="group.type"
|
||||
:group="group"
|
||||
:show-node-id-badge="showNodeIdBadge"
|
||||
@locate-node="emit('locate-node', $event)"
|
||||
@replace="emit('replace', $event)"
|
||||
/>
|
||||
@@ -14,8 +15,9 @@
|
||||
import type { SwapNodeGroup } from '@/components/rightSidePanel/errors/useErrorGroups'
|
||||
import SwapNodeGroupRow from '@/platform/nodeReplacement/components/SwapNodeGroupRow.vue'
|
||||
|
||||
const { swapNodeGroups } = defineProps<{
|
||||
const { swapNodeGroups, showNodeIdBadge } = defineProps<{
|
||||
swapNodeGroups: SwapNodeGroup[]
|
||||
showNodeIdBadge: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
|
||||
@@ -53,8 +53,7 @@ describe('useReleaseService', () => {
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0'
|
||||
},
|
||||
signal: undefined,
|
||||
headers: undefined
|
||||
signal: undefined
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockReleases)
|
||||
@@ -77,8 +76,7 @@ describe('useReleaseService', () => {
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'desktop-windows'
|
||||
},
|
||||
signal: undefined,
|
||||
headers: undefined
|
||||
signal: undefined
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockReleases)
|
||||
@@ -88,30 +86,11 @@ describe('useReleaseService', () => {
|
||||
const abortController = new AbortController()
|
||||
mockAxiosInstance.get.mockResolvedValue({ data: mockReleases })
|
||||
|
||||
await service.getReleases(
|
||||
{ project: 'comfyui' },
|
||||
{ signal: abortController.signal }
|
||||
)
|
||||
await service.getReleases({ project: 'comfyui' }, abortController.signal)
|
||||
|
||||
expect(mockAxiosInstance.get).toHaveBeenCalledWith('/releases', {
|
||||
params: { project: 'comfyui' },
|
||||
signal: abortController.signal,
|
||||
headers: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('should send Comfy-Env header when deployEnvironment is provided', async () => {
|
||||
mockAxiosInstance.get.mockResolvedValue({ data: mockReleases })
|
||||
|
||||
await service.getReleases(
|
||||
{ project: 'comfyui' },
|
||||
{ deployEnvironment: 'local-desktop' }
|
||||
)
|
||||
|
||||
expect(mockAxiosInstance.get).toHaveBeenCalledWith('/releases', {
|
||||
params: { project: 'comfyui' },
|
||||
signal: undefined,
|
||||
headers: { 'Comfy-Env': 'local-desktop' }
|
||||
signal: abortController.signal
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -98,9 +98,8 @@ export const useReleaseService = () => {
|
||||
// Fetch release notes from API
|
||||
const getReleases = async (
|
||||
params: GetReleasesParams,
|
||||
options: { signal?: AbortSignal; deployEnvironment?: string } = {}
|
||||
signal?: AbortSignal
|
||||
): Promise<ReleaseNote[] | null> => {
|
||||
const { signal, deployEnvironment } = options
|
||||
const endpoint = '/releases'
|
||||
const errorContext = 'Failed to get releases'
|
||||
const routeSpecificErrors = {
|
||||
@@ -111,10 +110,7 @@ export const useReleaseService = () => {
|
||||
() =>
|
||||
releaseApiClient.get<ReleaseNote[]>(endpoint, {
|
||||
params,
|
||||
signal,
|
||||
headers: deployEnvironment
|
||||
? { 'Comfy-Env': deployEnvironment }
|
||||
: undefined
|
||||
signal
|
||||
}),
|
||||
errorContext,
|
||||
routeSpecificErrors
|
||||
|
||||
@@ -228,15 +228,12 @@ describe('useReleaseStore', () => {
|
||||
|
||||
await store.initialize()
|
||||
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith(
|
||||
{
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'git-windows',
|
||||
locale: 'en'
|
||||
},
|
||||
{ deployEnvironment: undefined }
|
||||
)
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith({
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'git-windows',
|
||||
locale: 'en'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -303,15 +300,12 @@ describe('useReleaseStore', () => {
|
||||
|
||||
await store.initialize()
|
||||
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith(
|
||||
{
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'git-windows',
|
||||
locale: 'en'
|
||||
},
|
||||
{ deployEnvironment: undefined }
|
||||
)
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith({
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'git-windows',
|
||||
locale: 'en'
|
||||
})
|
||||
expect(store.releases).toEqual([mockRelease])
|
||||
})
|
||||
|
||||
@@ -324,30 +318,12 @@ describe('useReleaseStore', () => {
|
||||
|
||||
await store.initialize()
|
||||
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith(
|
||||
{
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'desktop-mac',
|
||||
locale: 'en'
|
||||
},
|
||||
{ deployEnvironment: undefined }
|
||||
)
|
||||
})
|
||||
|
||||
it('should pass deploy_environment from system stats', async () => {
|
||||
const store = useReleaseStore()
|
||||
const releaseService = useReleaseService()
|
||||
const systemStatsStore = useSystemStatsStore()
|
||||
systemStatsStore.systemStats!.system.deploy_environment = 'local-desktop'
|
||||
vi.mocked(releaseService.getReleases).mockResolvedValue([mockRelease])
|
||||
|
||||
await store.initialize()
|
||||
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
{ deployEnvironment: 'local-desktop' }
|
||||
)
|
||||
expect(releaseService.getReleases).toHaveBeenCalledWith({
|
||||
project: 'comfyui',
|
||||
current_version: '1.0.0',
|
||||
form_factor: 'desktop-mac',
|
||||
locale: 'en'
|
||||
})
|
||||
})
|
||||
|
||||
it('should skip fetching when --disable-api-nodes is present', async () => {
|
||||
|
||||
@@ -266,18 +266,12 @@ export const useReleaseStore = defineStore('release', () => {
|
||||
await until(systemStatsStore.isInitialized)
|
||||
}
|
||||
|
||||
const fetchedReleases = await releaseService.getReleases(
|
||||
{
|
||||
project: isCloud ? 'cloud' : 'comfyui',
|
||||
current_version: currentVersion.value,
|
||||
form_factor: systemStatsStore.getFormFactor(),
|
||||
locale: stringToLocale(locale.value)
|
||||
},
|
||||
{
|
||||
deployEnvironment:
|
||||
systemStatsStore.systemStats?.system?.deploy_environment
|
||||
}
|
||||
)
|
||||
const fetchedReleases = await releaseService.getReleases({
|
||||
project: isCloud ? 'cloud' : 'comfyui',
|
||||
current_version: currentVersion.value,
|
||||
form_factor: systemStatsStore.getFormFactor(),
|
||||
locale: stringToLocale(locale.value)
|
||||
})
|
||||
|
||||
if (fetchedReleases !== null) {
|
||||
releases.value = fetchedReleases
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import WorkspaceSwitcherPopover from './WorkspaceSwitcherPopover.vue'
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceSwitch', () => ({
|
||||
useWorkspaceSwitch: () => ({ switchWorkspace: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({ subscription: ref(null) })
|
||||
}))
|
||||
|
||||
const LONG_WORKSPACE_NAME =
|
||||
'Quantum Renaissance Collective for Hyperdimensional Latent Diffusion Research and Experimental Workflow Engineering'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
workspaceSwitcher: {
|
||||
personal: 'Personal',
|
||||
roleOwner: 'Owner',
|
||||
roleMember: 'Member',
|
||||
createWorkspace: 'Create new workspace',
|
||||
maxWorkspacesReached:
|
||||
'You can only own 10 workspaces. Delete one to create a new one.'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
function createWorkspaceState(overrides: Record<string, unknown>) {
|
||||
return {
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
joined_at: '2026-01-01T00:00:00Z',
|
||||
isSubscribed: false,
|
||||
subscriptionPlan: null,
|
||||
subscriptionTier: null,
|
||||
members: [],
|
||||
pendingInvites: [],
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function renderComponent() {
|
||||
return render(WorkspaceSwitcherPopover, {
|
||||
global: {
|
||||
plugins: [
|
||||
createTestingPinia({
|
||||
createSpy: vi.fn,
|
||||
initialState: {
|
||||
teamWorkspace: {
|
||||
activeWorkspaceId: 'ws-personal',
|
||||
isFetchingWorkspaces: false,
|
||||
workspaces: [
|
||||
createWorkspaceState({
|
||||
id: 'ws-personal',
|
||||
name: 'Personal Workspace',
|
||||
type: 'personal',
|
||||
role: 'owner'
|
||||
}),
|
||||
createWorkspaceState({
|
||||
id: 'ws-team-long',
|
||||
name: LONG_WORKSPACE_NAME,
|
||||
type: 'team',
|
||||
role: 'member'
|
||||
})
|
||||
]
|
||||
}
|
||||
}
|
||||
}),
|
||||
i18n
|
||||
],
|
||||
stubs: {
|
||||
WorkspaceProfilePic: true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('WorkspaceSwitcherPopover', () => {
|
||||
it('exposes the full team workspace name as a tooltip on the row', () => {
|
||||
renderComponent()
|
||||
|
||||
const name = screen.getByText(LONG_WORKSPACE_NAME)
|
||||
|
||||
expect(name).toHaveAttribute('title', LONG_WORKSPACE_NAME)
|
||||
})
|
||||
})
|
||||
@@ -34,20 +34,21 @@
|
||||
@click="handleSelectWorkspace(workspace)"
|
||||
>
|
||||
<WorkspaceProfilePic
|
||||
class="size-8 shrink-0 text-sm"
|
||||
class="size-8 text-sm"
|
||||
:workspace-name="workspace.name"
|
||||
/>
|
||||
<div class="flex min-w-0 flex-1 flex-col items-start gap-1">
|
||||
<div class="flex max-w-full items-center gap-1.5">
|
||||
<span
|
||||
:title="getDisplayName(workspace)"
|
||||
class="truncate text-sm text-base-foreground"
|
||||
>
|
||||
{{ getDisplayName(workspace) }}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-sm text-base-foreground">
|
||||
{{
|
||||
workspace.type === 'personal'
|
||||
? $t('workspaceSwitcher.personal')
|
||||
: workspace.name
|
||||
}}
|
||||
</span>
|
||||
<span
|
||||
v-if="resolveTierLabel(workspace)"
|
||||
class="shrink-0 rounded-full bg-base-foreground px-1 py-0.5 text-2xs font-bold text-base-background uppercase"
|
||||
class="rounded-full bg-base-foreground px-1 py-0.5 text-2xs font-bold text-base-background uppercase"
|
||||
>
|
||||
{{ resolveTierLabel(workspace) }}
|
||||
</span>
|
||||
@@ -58,7 +59,7 @@
|
||||
</div>
|
||||
<i
|
||||
v-if="isCurrentWorkspace(workspace)"
|
||||
class="pi pi-check shrink-0 text-sm text-base-foreground"
|
||||
class="pi pi-check text-sm text-base-foreground"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
@@ -170,12 +171,6 @@ function isCurrentWorkspace(workspace: AvailableWorkspace): boolean {
|
||||
return workspace.id === workspaceId.value
|
||||
}
|
||||
|
||||
function getDisplayName(workspace: AvailableWorkspace): string {
|
||||
return workspace.type === 'personal'
|
||||
? t('workspaceSwitcher.personal')
|
||||
: workspace.name
|
||||
}
|
||||
|
||||
function getRoleLabel(role: AvailableWorkspace['role']): string {
|
||||
if (role === 'owner') return t('workspaceSwitcher.roleOwner')
|
||||
if (role === 'member') return t('workspaceSwitcher.roleMember')
|
||||
|
||||
@@ -33,10 +33,6 @@ const WorkspaceTokenResponseSchema = z.object({
|
||||
permissions: z.array(z.string())
|
||||
})
|
||||
|
||||
export type WorkspaceTokenResponse = z.infer<
|
||||
typeof WorkspaceTokenResponseSchema
|
||||
>
|
||||
|
||||
export class WorkspaceAuthError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({ inputIsWidget: () => true })
|
||||
}))
|
||||
|
||||
// Serializes the nodeData prop so tests can assert on the data contract
|
||||
// LGraphNodePreview hands to NodeWidgets. How that data renders is covered
|
||||
// by NodeWidgets.test.ts and browser_tests/tests/sidebar/modelLibrary.spec.ts.
|
||||
const NodeWidgetsProbe = {
|
||||
props: ['nodeData'],
|
||||
template: '<div data-testid="node-data">{{ JSON.stringify(nodeData) }}</div>'
|
||||
}
|
||||
|
||||
interface ProbedWidget {
|
||||
name: string
|
||||
value?: unknown
|
||||
options?: { values?: string[] }
|
||||
}
|
||||
|
||||
const nodeDef = fromPartial<ComfyNodeDefV2>({
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
inputs: {
|
||||
ckpt_name: { type: 'COMBO', options: ['a.safetensors', 'b.safetensors'] }
|
||||
},
|
||||
outputs: []
|
||||
})
|
||||
|
||||
function renderedWidgets(
|
||||
def: ComfyNodeDefV2,
|
||||
props: { widgetValues?: Record<string, string> } = {}
|
||||
) {
|
||||
render(LGraphNodePreview, {
|
||||
props: { nodeDef: def, ...props },
|
||||
global: {
|
||||
stubs: {
|
||||
NodeHeader: true,
|
||||
NodeSlots: true,
|
||||
NodeWidgets: NodeWidgetsProbe
|
||||
}
|
||||
}
|
||||
})
|
||||
const nodeData: { widgets?: ProbedWidget[] } = JSON.parse(
|
||||
screen.getByTestId('node-data').textContent ?? ''
|
||||
)
|
||||
return nodeData.widgets ?? []
|
||||
}
|
||||
|
||||
function renderedComboWidget(
|
||||
props: { widgetValues?: Record<string, string> } = {}
|
||||
) {
|
||||
return renderedWidgets(nodeDef, props).find((w) => w.name === 'ckpt_name')
|
||||
}
|
||||
|
||||
describe('LGraphNodePreview', () => {
|
||||
it('leads the combo options with the provided widget value', () => {
|
||||
const widget = renderedComboWidget({
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
|
||||
expect(widget?.options?.values).toEqual([
|
||||
'sd_xl_base_1.0.safetensors',
|
||||
'a.safetensors',
|
||||
'b.safetensors'
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps the combo options untouched when no value is provided', () => {
|
||||
const widget = renderedComboWidget()
|
||||
|
||||
expect(widget?.options?.values).toEqual(['a.safetensors', 'b.safetensors'])
|
||||
})
|
||||
|
||||
it('uses the input default when defined and empty string otherwise', () => {
|
||||
const widgets = renderedWidgets(
|
||||
fromPartial<ComfyNodeDefV2>({
|
||||
name: 'TestNode',
|
||||
inputs: {
|
||||
steps: { type: 'INT', default: 20 },
|
||||
text: { type: 'STRING' }
|
||||
},
|
||||
outputs: []
|
||||
})
|
||||
)
|
||||
|
||||
expect(widgets.find((w) => w.name === 'steps')?.value).toBe(20)
|
||||
expect(widgets.find((w) => w.name === 'text')?.value).toBe('')
|
||||
})
|
||||
})
|
||||
@@ -45,9 +45,14 @@ import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSc
|
||||
import { useWidgetStore } from '@/stores/widgetStore'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
const { nodeDef, position = 'absolute' } = defineProps<{
|
||||
const {
|
||||
nodeDef,
|
||||
position = 'absolute',
|
||||
widgetValues
|
||||
} = defineProps<{
|
||||
nodeDef: ComfyNodeDefV2
|
||||
position?: 'absolute' | 'relative'
|
||||
widgetValues?: Record<string, string>
|
||||
}>()
|
||||
|
||||
const widgetStore = useWidgetStore()
|
||||
@@ -56,27 +61,32 @@ const widgetStore = useWidgetStore()
|
||||
const nodeData = computed<VueNodeData>(() => {
|
||||
const widgets = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => widgetStore.inputIsWidget(input))
|
||||
.map(([name, input]) => ({
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: input.type === 'COMBO' &&
|
||||
Array.isArray(input.options) &&
|
||||
input.options.length > 0
|
||||
? input.options[0]
|
||||
: '',
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
} satisfies IWidgetOptions
|
||||
}))
|
||||
.map(([name, input]) => {
|
||||
const comboValues =
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
// Preview nodes have no widget-value store entry, so combo widgets
|
||||
// render their first option; lead with the requested value to show it.
|
||||
const leadValue = widgetValues?.[name]
|
||||
return {
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: (comboValues?.[0] ?? ''),
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
leadValue && comboValues
|
||||
? [leadValue, ...comboValues.filter((o) => o !== leadValue)]
|
||||
: comboValues
|
||||
} satisfies IWidgetOptions
|
||||
}
|
||||
})
|
||||
|
||||
const inputs: INodeInputSlot[] = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => !widgetStore.inputIsWidget(input))
|
||||
|
||||
@@ -51,9 +51,6 @@ const AudioPreviewPlayer = defineAsyncComponent(
|
||||
const Load3D = defineAsyncComponent(
|
||||
() => import('@/components/load3d/Load3D.vue')
|
||||
)
|
||||
const Load3DAdvanced = defineAsyncComponent(
|
||||
() => import('@/components/load3d/Load3DAdvanced.vue')
|
||||
)
|
||||
const WidgetImageCrop = defineAsyncComponent(
|
||||
() => import('@/components/imagecrop/WidgetImageCrop.vue')
|
||||
)
|
||||
@@ -172,14 +169,6 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [
|
||||
}
|
||||
],
|
||||
['load3D', { component: Load3D, aliases: ['LOAD_3D'], essential: false }],
|
||||
[
|
||||
'load3DAdvanced',
|
||||
{
|
||||
component: Load3DAdvanced,
|
||||
aliases: ['LOAD_3D_ADVANCED'],
|
||||
essential: false
|
||||
}
|
||||
],
|
||||
[
|
||||
'imagecrop',
|
||||
{
|
||||
@@ -254,7 +243,6 @@ const EXPANDING_TYPES = [
|
||||
'textarea',
|
||||
'markdown',
|
||||
'load3D',
|
||||
'load3DAdvanced',
|
||||
'curve',
|
||||
'painter',
|
||||
'imagecompare',
|
||||
|
||||
@@ -252,7 +252,6 @@ const zSystemStats = z.object({
|
||||
python_version: z.string(),
|
||||
embedded_python: z.boolean(),
|
||||
comfyui_version: z.string(),
|
||||
deploy_environment: z.string().optional(),
|
||||
pytorch_version: z.string(),
|
||||
required_frontend_version: z.string().optional(),
|
||||
argv: z.array(z.string()),
|
||||
|
||||
@@ -108,11 +108,6 @@ interface QueuePromptRequestBody {
|
||||
* ```
|
||||
*/
|
||||
api_key_comfy_org?: string
|
||||
/**
|
||||
* Identifies the client submitting the prompt. Forwarded by the backend
|
||||
* to API nodes' upstream requests via the Comfy-Usage-Source header.
|
||||
*/
|
||||
comfy_usage_source?: string
|
||||
/**
|
||||
* Override the preview method for this prompt execution.
|
||||
* 'default' uses the server's CLI setting.
|
||||
@@ -872,7 +867,6 @@ export class ComfyApi extends EventTarget {
|
||||
extra_data: {
|
||||
auth_token_comfy_org: this.authToken,
|
||||
api_key_comfy_org: this.apiKey,
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
extra_pnginfo: { workflow },
|
||||
...(options?.previewMethod &&
|
||||
options.previewMethod !== 'default' && {
|
||||
|
||||