Handle labels for machine clients

This commit is contained in:
Fred KISSIE
2026-05-12 22:32:56 +02:00
parent 7120ab4b22
commit ce746a2a21
7 changed files with 320 additions and 30 deletions

View File

@@ -227,6 +227,24 @@ export const siteResourceLabels = pgTable(
(t) => [unique("site_resource_label_uniq").on(t.siteResourceId, t.labelId)] (t) => [unique("site_resource_label_uniq").on(t.siteResourceId, t.labelId)]
); );
export const clientLabels = pgTable(
"clientLabels",
{
clientLabelId: serial("clientLabelId").primaryKey(),
clientId: integer("clientId")
.references(() => clients.clientId, {
onDelete: "cascade"
})
.notNull(),
labelId: integer("labelId")
.references(() => labels.labelId, {
onDelete: "cascade"
})
.notNull()
},
(t) => [unique("client_label_uniq").on(t.clientId, t.labelId)]
);
export const targets = pgTable("targets", { export const targets = pgTable("targets", {
targetId: serial("targetId").primaryKey(), targetId: serial("targetId").primaryKey(),
resourceId: integer("resourceId") resourceId: integer("resourceId")

View File

@@ -252,6 +252,26 @@ export const siteResourceLabels = sqliteTable(
(t) => [unique("site_resource_label_uniq").on(t.siteResourceId, t.labelId)] (t) => [unique("site_resource_label_uniq").on(t.siteResourceId, t.labelId)]
); );
export const clientLabels = sqliteTable(
"clientLabels",
{
clientLabelId: integer("clientLabelId").primaryKey({
autoIncrement: true
}),
clientId: integer("clientId")
.references(() => clients.clientId, {
onDelete: "cascade"
})
.notNull(),
labelId: integer("labelId")
.references(() => labels.labelId, {
onDelete: "cascade"
})
.notNull()
},
(t) => [unique("client_label_uniq").on(t.clientId, t.labelId)]
);
export const targets = sqliteTable("targets", { export const targets = sqliteTable("targets", {
targetId: integer("targetId").primaryKey({ autoIncrement: true }), targetId: integer("targetId").primaryKey({ autoIncrement: true }),
resourceId: integer("resourceId") resourceId: integer("resourceId")

View File

@@ -12,6 +12,8 @@
*/ */
import { import {
clients,
clientLabels,
db, db,
labels, labels,
resourceLabels, resourceLabels,
@@ -24,7 +26,7 @@ import {
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -38,7 +40,8 @@ const paramsSchema = z.strictObject({
const attachLabelBodySchema = z.strictObject({ const attachLabelBodySchema = z.strictObject({
siteId: z.number().int().optional(), siteId: z.number().int().optional(),
resourceId: z.number().int().optional(), resourceId: z.number().int().optional(),
siteResourceId: z.number().int().optional() siteResourceId: z.number().int().optional(),
clientId: z.number().int().optional()
}); });
export async function attachLabelToItem( export async function attachLabelToItem(
@@ -69,13 +72,14 @@ export async function attachLabelToItem(
); );
} }
const { siteId, resourceId, siteResourceId } = parsedBody.data; const { siteId, resourceId, siteResourceId, clientId } =
parsedBody.data;
if (!siteId && !resourceId && !siteResourceId) { if (!siteId && !resourceId && !siteResourceId && !clientId) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
"At least one of `siteId`, `resourceId` or `siteResourceId` should be provided." "At least one of `siteId`, `resourceId`, `siteResourceId` or `clientId` should be provided."
) )
); );
} }
@@ -175,6 +179,35 @@ export async function attachLabelToItem(
.onConflictDoNothing(); .onConflictDoNothing();
} }
if (clientId) {
const clientCount = await db.$count(
clients,
and(
eq(clients.clientId, clientId),
eq(clients.orgId, orgId),
isNull(clients.userId)
)
);
if (clientCount === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with Id ${clientId} doesn't exist.`
)
);
}
// idempotent, calling this endpoint multiple times should attach the label only once
await db
.insert(clientLabels)
.values({
labelId,
clientId
})
.onConflictDoNothing();
}
return response(res, { return response(res, {
data: {}, data: {},
success: true, success: true,

View File

@@ -12,6 +12,8 @@
*/ */
import { import {
clients,
clientLabels,
db, db,
labels, labels,
resourceLabels, resourceLabels,
@@ -24,7 +26,7 @@ import {
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -38,7 +40,8 @@ const paramsSchema = z.strictObject({
const detachLabelBodySchema = z.strictObject({ const detachLabelBodySchema = z.strictObject({
siteId: z.number().int().optional(), siteId: z.number().int().optional(),
resourceId: z.number().int().optional(), resourceId: z.number().int().optional(),
siteResourceId: z.number().int().optional() siteResourceId: z.number().int().optional(),
clientId: z.number().int().optional()
}); });
export async function detachLabelFromItem( export async function detachLabelFromItem(
@@ -69,13 +72,14 @@ export async function detachLabelFromItem(
); );
} }
const { siteId, resourceId, siteResourceId } = parsedBody.data; const { siteId, resourceId, siteResourceId, clientId } =
parsedBody.data;
if (!siteId && !resourceId && !siteResourceId) { if (!siteId && !resourceId && !siteResourceId && !clientId) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
"At least one of `siteId`, `siteResourceId` or `resourceId` should be provided." "At least one of `siteId`, `resourceId`, `siteResourceId` or `clientId` should be provided."
) )
); );
} }
@@ -175,6 +179,35 @@ export async function detachLabelFromItem(
); );
} }
if (clientId) {
const clientCount = await db.$count(
clients,
and(
eq(clients.clientId, clientId),
eq(clients.orgId, orgId),
isNull(clients.userId)
)
);
if (clientCount === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with Id ${clientId} doesn't exist.`
)
);
}
await db
.delete(clientLabels)
.where(
and(
eq(clientLabels.labelId, labelId),
eq(clientLabels.clientId, clientId)
)
);
}
return response(res, { return response(res, {
data: {}, data: {},
success: true, success: true,

View File

@@ -1,15 +1,20 @@
import { import {
clientLabels,
clients, clients,
clientSitesAssociationsCache, clientSitesAssociationsCache,
currentFingerprint, currentFingerprint,
db, db,
labels,
olms, olms,
orgs, orgs,
roleClients, roleClients,
sites, sites,
userClients, userClients,
users users,
type Label
} from "@server/db"; } from "@server/db";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -169,6 +174,7 @@ type ClientWithSites = Awaited<ReturnType<typeof queryClientsBase>>[0] & {
siteNiceId: string | null; siteNiceId: string | null;
}>; }>;
olmUpdateAvailable?: boolean; olmUpdateAvailable?: boolean;
labels?: Array<Pick<Label, "labelId" | "name" | "color">>;
}; };
type OlmWithUpdateAvailable = ClientWithSites; type OlmWithUpdateAvailable = ClientWithSites;
@@ -255,6 +261,11 @@ export async function listClients(
(client) => client.clientId (client) => client.clientId
); );
const isLabelFeatureEnabled = await isLicensedOrSubscribed(
orgId,
tierMatrix.labels
);
// Get client count with filter // Get client count with filter
const conditions = [ const conditions = [
and( and(
@@ -288,18 +299,29 @@ export async function listClients(
} }
if (query) { if (query) {
conditions.push( const q = "%" + query.toLowerCase() + "%";
or( const queryList = [
like( like(sql`LOWER(${clients.name})`, q),
sql`LOWER(${clients.name})`, like(sql`LOWER(${clients.niceId})`, q)
"%" + query.toLowerCase() + "%" ];
),
like( if (isLabelFeatureEnabled) {
sql`LOWER(${clients.niceId})`, queryList.push(
"%" + query.toLowerCase() + "%" inArray(
clients.clientId,
db
.select({ id: clientLabels.clientId })
.from(clientLabels)
.innerJoin(
labels,
eq(labels.labelId, clientLabels.labelId)
)
.where(like(sql`LOWER(${labels.name})`, q))
) )
) );
); }
conditions.push(or(...queryList));
} }
const baseQuery = queryClientsBase().where(and(...conditions)); const baseQuery = queryClientsBase().where(and(...conditions));
@@ -326,6 +348,30 @@ export async function listClients(
const clientIds = clientsList.map((client) => client.clientId); const clientIds = clientsList.map((client) => client.clientId);
const siteAssociations = await getSiteAssociations(clientIds); const siteAssociations = await getSiteAssociations(clientIds);
let labelsForClients: Array<{
labelId: number;
name: string;
color: string;
clientId: number;
}> = [];
if (isLabelFeatureEnabled && clientIds.length > 0) {
labelsForClients = await db
.select({
labelId: labels.labelId,
name: labels.name,
color: labels.color,
clientId: clientLabels.clientId
})
.from(labels)
.innerJoin(
clientLabels,
eq(clientLabels.labelId, labels.labelId)
)
.where(inArray(clientLabels.clientId, clientIds))
.orderBy(asc(clientLabels.clientLabelId));
}
// Group site associations by client ID // Group site associations by client ID
const sitesByClient = siteAssociations.reduce( const sitesByClient = siteAssociations.reduce(
(acc, association) => { (acc, association) => {
@@ -353,7 +399,10 @@ export async function listClients(
const clientsWithSites = clientsList.map((client) => { const clientsWithSites = clientsList.map((client) => {
return { return {
...client, ...client,
sites: sitesByClient[client.clientId] || [] sites: sitesByClient[client.clientId] || [],
labels: labelsForClients.filter(
(l) => l.clientId === client.clientId
)
}; };
}); });

View File

@@ -76,7 +76,8 @@ export default async function ClientsPage(props: ClientsPageProps) {
agent: client.agent, agent: client.agent,
archived: client.archived || false, archived: client.archived || false,
blocked: client.blocked || false, blocked: client.blocked || false,
approvalState: client.approvalState ?? "approved" approvalState: client.approvalState ?? "approved",
labels: client.labels ?? []
}; };
}; };

View File

@@ -10,8 +10,11 @@ import {
DropdownMenuTrigger DropdownMenuTrigger
} from "@app/components/ui/dropdown-menu"; } from "@app/components/ui/dropdown-menu";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { usePaidStatus } from "@app/hooks/usePaidStatus";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api"; import { createApiClient, formatAxiosError } from "@app/lib/api";
import { cn } from "@app/lib/cn";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { import {
ArrowRight, ArrowRight,
ArrowUpDown, ArrowUpDown,
@@ -19,12 +22,26 @@ import {
CircleSlash, CircleSlash,
ArrowDown01Icon, ArrowDown01Icon,
ArrowUp10Icon, ArrowUp10Icon,
ChevronsUpDownIcon ChevronsUpDownIcon,
PlusIcon
} from "lucide-react"; } from "lucide-react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import Link from "next/link"; import Link from "next/link";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useMemo, useState, useTransition } from "react"; import {
startTransition,
useMemo,
useOptimistic,
useState,
useTransition
} from "react";
import { LabelBadge } from "./label-badge";
import { LabelsSelector, type SelectedLabel } from "./labels-selector";
import {
Popover,
PopoverContent,
PopoverTrigger
} from "./ui/popover";
import { Badge } from "./ui/badge"; import { Badge } from "./ui/badge";
import type { PaginationState } from "@tanstack/react-table"; import type { PaginationState } from "@tanstack/react-table";
import { ControlledDataTable } from "./ui/controlled-data-table"; import { ControlledDataTable } from "./ui/controlled-data-table";
@@ -53,6 +70,11 @@ export type ClientRow = {
archived?: boolean; archived?: boolean;
blocked?: boolean; blocked?: boolean;
approvalState: "approved" | "pending" | "denied"; approvalState: "approved" | "pending" | "denied";
labels?: Array<{
labelId: number;
name: string;
color: string;
}>;
}; };
type ClientTableProps = { type ClientTableProps = {
@@ -84,17 +106,21 @@ export default function MachineClientsTable({
); );
const api = createApiClient(useEnvContext()); const api = createApiClient(useEnvContext());
const [isRefreshing, startTransition] = useTransition(); const [isRefreshing, startRefreshTransition] = useTransition();
const [isNavigatingToAddPage, startNavigation] = useTransition(); const [isNavigatingToAddPage, startNavigation] = useTransition();
const { isPaidUser } = usePaidStatus();
const isLabelFeatureEnabled = isPaidUser(tierMatrix.labels);
const defaultMachineColumnVisibility = { const defaultMachineColumnVisibility = {
subnet: false, subnet: false,
userId: false, userId: false,
niceId: false niceId: false,
labels: false
}; };
const refreshData = () => { const refreshData = () => {
startTransition(() => { startRefreshTransition(() => {
try { try {
router.refresh(); router.refresh();
} catch (error) { } catch (error) {
@@ -384,6 +410,24 @@ export default function MachineClientsTable({
} }
]; ];
if (isLabelFeatureEnabled) {
baseColumns.push({
id: "labels",
accessorKey: "labels",
header: () => (
<span className="p-3 text-end w-full inline-block">
{t("labels")}
</span>
),
cell: ({ row }: { row: { original: ClientRow } }) => (
<MachineClientLabelCell
client={row.original}
orgId={orgId}
/>
)
});
}
// Only include actions column if there are rows without userIds // Only include actions column if there are rows without userIds
if (hasRowsWithoutUserId) { if (hasRowsWithoutUserId) {
baseColumns.push({ baseColumns.push({
@@ -464,7 +508,7 @@ export default function MachineClientsTable({
} }
return baseColumns; return baseColumns;
}, [hasRowsWithoutUserId, t, getSortDirection, toggleSort]); }, [hasRowsWithoutUserId, isLabelFeatureEnabled, orgId, t, searchParams]);
const booleanSearchFilterSchema = z const booleanSearchFilterSchema = z
.enum(["true", "false"]) .enum(["true", "false"])
@@ -591,3 +635,95 @@ export default function MachineClientsTable({
</> </>
); );
} }
type MachineClientLabelCellProps = {
client: ClientRow;
orgId: string;
};
function MachineClientLabelCell({ client, orgId }: MachineClientLabelCellProps) {
const t = useTranslations();
const api = createApiClient(useEnvContext());
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const router = useRouter();
const labels = client.labels ?? [];
const [optimisticLabels, setOptimisticLabels] = useOptimistic(labels);
function toggleClientLabel(label: SelectedLabel, action: "attach" | "detach") {
startTransition(async () => {
try {
if (action === "attach") {
setOptimisticLabels([...optimisticLabels, label]);
await api.put(
`/org/${orgId}/label/${label.labelId}/attach`,
{ clientId: client.id }
);
} else {
setOptimisticLabels(
optimisticLabels.filter(
(lb) => lb.labelId !== label.labelId
)
);
await api.put(
`/org/${orgId}/label/${label.labelId}/detach`,
{ clientId: client.id }
);
}
} catch (e) {
toast({
title: t("error"),
description: formatAxiosError(e, t("errorOccurred")),
variant: "destructive"
});
} finally {
router.refresh();
}
});
}
return (
<div className="inline-flex flex-wrap items-center justify-end w-full gap-1">
{optimisticLabels.slice(0, 3).map((label) => (
<LabelBadge
key={label.labelId}
onClick={() => setIsPopoverOpen(true)}
{...label}
/>
))}
{optimisticLabels.length > 3 && (
<Button
variant="outline"
className={cn(
"inline-flex gap-1 items-center",
"rounded-full text-sm cursor-pointer",
"px-1.5 py-0 h-auto"
)}
onClick={() => setIsPopoverOpen(true)}
>
+{optimisticLabels.length - 3}
</Button>
)}
<Popover open={isPopoverOpen} onOpenChange={setIsPopoverOpen}>
<PopoverTrigger asChild>
<Button
size="icon"
variant="outline"
className="p-1 size-auto rounded-full"
title={t("addLabels")}
>
<span className="sr-only">{t("addLabels")}</span>
<PlusIcon className="size-3" />
</Button>
</PopoverTrigger>
<PopoverContent align="center" className="p-0 w-full">
<LabelsSelector
orgId={orgId}
selectedLabels={optimisticLabels}
toggleLabel={toggleClientLabel}
/>
</PopoverContent>
</Popover>
</div>
);
}