testing oidc callback

This commit is contained in:
miloschwartz
2025-04-12 15:39:15 -04:00
parent 9cb215295a
commit 480a5f648d
15 changed files with 997 additions and 7 deletions

View File

@@ -65,7 +65,8 @@ export enum ActionsEnum {
listResourceRules = "listResourceRules",
updateResourceRule = "updateResourceRule",
listOrgDomains = "listOrgDomains",
createNewt = "createNewt"
createNewt = "createNewt",
createIdp = "createIdp"
}
export async function checkUserActionPermission(

View File

@@ -0,0 +1,126 @@
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import {
IdpSession,
idpSessions,
IdpUser,
idpUser,
resourceSessions
} from "@server/db/schemas";
import db from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
import config from "@server/lib/config";
import cookie from "cookie";
const SESSION_COOKIE_EXPIRES =
1000 *
60 *
60 *
config.getRawConfig().server.dashboard_session_length_hours;
const COOKIE_DOMAIN =
"." + new URL(config.getRawConfig().app.dashboard_url).hostname;
export async function createIdpSession(
token: string,
idpUserId: string
): Promise<IdpSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token))
);
const session: IdpSession = {
idpSessionId: sessionId,
idpUserId,
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime()
};
await db.insert(idpSessions).values(session);
return session;
}
export async function validateIdpSessionToken(
token: string
): Promise<IdpSessionValidationResult> {
const idpSessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ idpUser: idpUser, idpSession: idpSessions })
.from(idpSessions)
.innerJoin(idpUser, eq(idpSessions.idpUserId, idpUser.idpUserId))
.where(eq(idpSessions.idpSessionId, idpSessionId));
if (result.length < 1) {
return { session: null, user: null };
}
const { idpUser: idpUserRes, idpSession: idpSessionRes } = result[0];
if (Date.now() >= idpSessionRes.expiresAt) {
await db
.delete(idpSessions)
.where(eq(idpSessions.idpSessionId, idpSessionRes.idpSessionId));
return { session: null, user: null };
}
if (Date.now() >= idpSessionRes.expiresAt - SESSION_COOKIE_EXPIRES / 2) {
idpSessionRes.expiresAt = new Date(
Date.now() + SESSION_COOKIE_EXPIRES
).getTime();
await db.transaction(async (trx) => {
await trx
.update(idpSessions)
.set({
expiresAt: idpSessionRes.expiresAt
})
.where(
eq(idpSessions.idpSessionId, idpSessionRes.idpSessionId)
);
await trx
.update(resourceSessions)
.set({
expiresAt: idpSessionRes.expiresAt
})
.where(
eq(
resourceSessions.idpSessionId,
idpSessionRes.idpSessionId
)
);
});
}
return { session: idpSessionRes, user: idpUserRes };
}
export async function invalidateIdpSession(
idpSessionId: string
): Promise<void> {
try {
await db.transaction(async (trx) => {
await trx
.delete(resourceSessions)
.where(eq(resourceSessions.idpSessionId, idpSessionId));
await trx
.delete(idpSessions)
.where(eq(idpSessions.idpSessionId, idpSessionId));
});
} catch (e) {
logger.error("Failed to invalidate session", e);
}
}
export function serializeIdpSessionCookie(
cookieName: string,
token: string,
isSecure: boolean,
expiresAt: Date
): string {
return cookie.serialize(cookieName, token, {
httpOnly: true,
sameSite: "lax",
expires: expiresAt,
path: "/",
secure: isSecure,
domain: COOKIE_DOMAIN
});
}
export type IdpSessionValidationResult =
| { session: IdpSession; user: IdpUser }
| { session: null; user: null };

View File

@@ -340,6 +340,12 @@ export const resourceSessions = sqliteTable("resourceSessions", {
.notNull()
.default(false),
isRequestToken: integer("isRequestToken", { mode: "boolean" }),
idpSessionId: text("idpSessionId").references(
() => idpSessions.idpSessionId,
{
onDelete: "cascade"
}
),
userSessionId: text("userSessionId").references(() => sessions.sessionId, {
onDelete: "cascade"
}),
@@ -415,6 +421,77 @@ export const supporterKey = sqliteTable("supporterKey", {
valid: integer("valid", { mode: "boolean" }).notNull().default(false)
});
// Identity Providers
export const idp = sqliteTable("idp", {
idpId: integer("idpId").primaryKey({ autoIncrement: true }),
type: text("type").notNull()
});
// Identity Provider OAuth Configuration
export const idpOidcConfig = sqliteTable("idpOidcConfig", {
idpOauthConfigId: integer("idpOauthConfigId").primaryKey({
autoIncrement: true
}),
idpId: integer("idpId")
.notNull()
.references(() => idp.idpId, { onDelete: "cascade" }),
clientId: text("clientId").notNull(),
clientSecret: text("clientSecret").notNull(),
authUrl: text("authUrl").notNull(),
tokenUrl: text("tokenUrl").notNull(),
autoProvision: integer("autoProvision", {
mode: "boolean"
})
.notNull()
.default(false),
identifierPath: text("identifierPath").notNull(),
emailPath: text("emailPath"), // by default, this is "email"
namePath: text("namePath"), // by default, this is "name"
roleMapping: text("roleMapping"),
scopes: text("scopes").notNull()
});
export const idpOrg = sqliteTable("idpOrg", {
idpId: integer("idpId")
.notNull()
.references(() => idp.idpId, { onDelete: "cascade" }),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
// IDP User
export const idpUser = sqliteTable("idpUser", {
idpUserId: text("idpUserId").primaryKey(),
identifier: text("identifier").notNull(),
idpId: integer("idpId")
.notNull()
.references(() => idp.idpId, { onDelete: "cascade" }),
email: text("email"),
name: text("name")
});
// IDP User Organization Link
export const idpUserOrg = sqliteTable("idpUserOrg", {
idpUserId: text("idpUserId")
.notNull()
.references(() => idpUser.idpUserId, { onDelete: "cascade" }),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
});
export const idpSessions = sqliteTable("idpSessions", {
idpSessionId: text("idpSessionId").primaryKey(),
idpUserId: text("idpUserId")
.notNull()
.references(() => idpUser.idpUserId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull()
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -450,3 +527,8 @@ export type VersionMigration = InferSelectModel<typeof versionMigrations>;
export type ResourceRule = InferSelectModel<typeof resourceRules>;
export type Domain = InferSelectModel<typeof domains>;
export type SupporterKey = InferSelectModel<typeof supporterKey>;
export type Idp = InferSelectModel<typeof idp>;
export type IdpUser = InferSelectModel<typeof idpUser>;
export type IdpOrg = InferSelectModel<typeof idpOrg>;
export type IdpUserOrg = InferSelectModel<typeof idpUserOrg>;
export type IdpSession = InferSelectModel<typeof idpSessions>;

View File

@@ -0,0 +1,8 @@
import config from "@server/lib/config";
export function generateOidcRedirectUrl(orgId: string, idpId: number) {
const dashboardUrl = config.getRawConfig().app.dashboard_url;
const redirectPath = `/auth/org/${orgId}/idp/${idpId}/oidc/callback`;
const redirectUrl = new URL(redirectPath, dashboardUrl).toString();
return redirectUrl;
}

View File

@@ -11,5 +11,6 @@ export enum OpenAPITags {
Invitation = "Invitation",
Target = "Target",
Rule = "Rule",
AccessToken = "Access Token"
AccessToken = "Access Token",
Idp = "Identity Provider"
}

View File

@@ -10,6 +10,7 @@ import * as auth from "./auth";
import * as role from "./role";
import * as supporterKey from "./supporterKey";
import * as accessToken from "./accessToken";
import * as idp from "./idp";
import HttpCode from "@server/types/HttpCode";
import {
verifyAccessTokenAccess,
@@ -493,6 +494,13 @@ authenticated.delete(
// createNewt
// );
authenticated.put(
"/org/:orgId/idp/oidc",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createIdp),
idp.createOidcIdp
)
// Auth routes
export const authRouter = Router();
unauthenticated.use("/auth", authRouter);
@@ -581,4 +589,17 @@ authRouter.post(
resource.authWithAccessToken
);
authRouter.post("/access-token", resource.authWithAccessToken);
authRouter.post(
"/access-token",
resource.authWithAccessToken
);
authRouter.post(
"/org/:orgId/idp/:idpId/oidc/generate-url",
idp.generateOidcUrl
)
authRouter.post(
"/org/:orgId/idp/:idpId/oidc/validate-callback",
idp.validateOidcCallback
)

View File

@@ -0,0 +1,158 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db/schemas";
import { eq } from "drizzle-orm";
import { generateOidcUrl } from "./generateOidcUrl";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
const paramsSchema = z
.object({
orgId: z.string()
})
.strict();
const bodySchema = z
.object({
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.string().url(),
tokenUrl: z.string().url(),
autoProvision: z.boolean(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
roleMapping: z.string().optional(),
scopes: z.array(z.string().nonempty())
})
.strict();
export type CreateIdpResponse = {
idpId: number;
redirectUrl: string;
};
registry.registerPath({
method: "put",
path: "/org/{orgId}/idp/oidc",
description: "Create an OIDC IdP for an organization.",
tags: [OpenAPITags.Org, OpenAPITags.Idp],
request: {
params: paramsSchema,
body: {
content: {
"application/json": {
schema: bodySchema
}
}
}
},
responses: {}
});
export async function createOidcIdp(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const {
clientId,
clientSecret,
authUrl,
tokenUrl,
scopes,
identifierPath,
emailPath,
namePath,
roleMapping,
autoProvision
} = parsedBody.data;
// Check if the org exists
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
let idpId: number | undefined;
await db.transaction(async (trx) => {
const [idpRes] = await trx
.insert(idp)
.values({
type: "oidc"
})
.returning();
idpId = idpRes.idpId;
await trx.insert(idpOidcConfig).values({
idpId: idpRes.idpId,
clientId,
clientSecret,
authUrl,
tokenUrl,
autoProvision,
scopes: JSON.stringify(scopes),
identifierPath,
emailPath,
namePath,
roleMapping
});
await trx.insert(idpOrg).values({
idpId: idpRes.idpId,
orgId
});
});
const redirectUrl = generateOidcRedirectUrl(orgId, idpId as number);
return response<CreateIdpResponse>(res, {
data: {
idpId: idpId as number,
redirectUrl
},
success: true,
error: false,
message: "Idp created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,116 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { idp, idpOidcConfig, idpOrg } from "@server/db/schemas";
import { and, eq } from "drizzle-orm";
import * as arctic from "arctic";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import cookie from "cookie";
const paramsSchema = z
.object({
orgId: z.string(),
idpId: z.coerce.number()
})
.strict();
export type GenerateOidcUrlResponse = {
redirectUrl: string;
};
export async function generateOidcUrl(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId, idpId } = parsedParams.data;
const [existingIdp] = await db
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.innerJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId))
.where(
and(
eq(idpOrg.orgId, orgId),
eq(idp.type, "oidc"),
eq(idp.idpId, idpId)
)
);
if (!existingIdp) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IdP not found for the organization"
)
);
}
const parsedScopes = JSON.parse(existingIdp.idpOidcConfig.scopes);
const redirectUrl = generateOidcRedirectUrl(orgId, idpId);
const client = new arctic.OAuth2Client(
existingIdp.idpOidcConfig.clientId,
existingIdp.idpOidcConfig.clientSecret,
redirectUrl
);
const codeVerifier = arctic.generateCodeVerifier();
const state = arctic.generateState();
const url = client.createAuthorizationURLWithPKCE(
existingIdp.idpOidcConfig.authUrl,
state,
arctic.CodeChallengeMethod.S256,
codeVerifier,
parsedScopes
);
res.cookie("oidc_state", state, {
path: "/",
httpOnly: true,
secure: req.protocol === "https",
expires: new Date(Date.now() + 60 * 10 * 1000),
sameSite: "lax"
});
res.cookie(`oidc_code_verifier`, codeVerifier, {
path: "/",
httpOnly: true,
secure: req.protocol === "https",
expires: new Date(Date.now() + 60 * 10 * 1000),
sameSite: "lax"
});
return response<GenerateOidcUrlResponse>(res, {
data: {
redirectUrl: url.toString()
},
success: true,
error: false,
message: "Idp auth url generated",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,3 @@
export * from "./createOidcIdp";
export * from "./generateOidcUrl";
export * from "./validateOidcCallback";

View File

@@ -0,0 +1,250 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import {
idp,
idpOidcConfig,
idpOrg,
idpUser,
idpUserOrg,
Role,
roles
} from "@server/db/schemas";
import { and, eq } from "drizzle-orm";
import * as arctic from "arctic";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import jmespath from "jmespath";
import { generateId, generateSessionToken } from "@server/auth/sessions/app";
import {
createIdpSession,
serializeIdpSessionCookie
} from "@server/auth/sessions/orgIdp";
const paramsSchema = z
.object({
orgId: z.string(),
idpId: z.coerce.number()
})
.strict();
const bodySchema = z.object({
code: z.string().nonempty(),
codeVerifier: z.string().nonempty()
});
export type ValidateOidcUrlCallbackResponse = {};
export async function validateOidcCallback(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId, idpId } = parsedParams.data;
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { code, codeVerifier } = parsedBody.data;
const [existingIdp] = await db
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.innerJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId))
.where(
and(
eq(idpOrg.orgId, orgId),
eq(idp.type, "oidc"),
eq(idp.idpId, idpId)
)
);
if (!existingIdp) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IdP not found for the organization"
)
);
}
const redirectUrl = generateOidcRedirectUrl(
orgId,
existingIdp.idp.idpId
);
const client = new arctic.OAuth2Client(
existingIdp.idpOidcConfig.clientId,
existingIdp.idpOidcConfig.clientSecret,
redirectUrl
);
const tokens = await client.validateAuthorizationCode(
existingIdp.idpOidcConfig.tokenUrl,
code,
codeVerifier
);
const idToken = tokens.idToken();
const claims = arctic.decodeIdToken(idToken);
const userIdentifier = jmespath.search(
claims,
existingIdp.idpOidcConfig.identifierPath
);
if (!userIdentifier) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User identifier not found in the ID token"
)
);
}
logger.debug("User identifier", { userIdentifier });
const email = jmespath.search(
claims,
existingIdp.idpOidcConfig.emailPath || "email"
);
const name = jmespath.search(
claims,
existingIdp.idpOidcConfig.namePath || "name"
);
logger.debug("User email", { email });
logger.debug("User name", { name });
const [existingIdpUser] = await db
.select()
.from(idpUser)
.innerJoin(idpUserOrg, eq(idpUserOrg.idpUserId, idpUser.idpUserId))
.where(
and(
eq(idpUserOrg.orgId, orgId),
eq(idpUser.idpId, existingIdp.idp.idpId)
)
);
let userRole: Role | undefined;
if (existingIdp.idpOidcConfig.roleMapping) {
const roleName = jmespath.search(
claims,
existingIdp.idpOidcConfig.roleMapping
);
if (!roleName) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Role mapping not found in the ID token"
)
);
}
const [roleRes] = await db
.select()
.from(roles)
.where(and(eq(roles.orgId, orgId), eq(roles.name, roleName)));
userRole = roleRes;
} else {
// TODO: Get the default role for this IDP?
}
logger.debug("User role", { userRole });
if (!userRole) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Role not found for the user"
)
);
}
let userId: string | undefined = existingIdpUser?.idpUser.idpUserId;
if (!existingIdpUser) {
if (existingIdp.idpOidcConfig.autoProvision) {
// TODO: Create the user and automatically assign roles
await db.transaction(async (trx) => {
const idpUserId = generateId(15);
const [idpUserRes] = await trx
.insert(idpUser)
.values({
idpUserId,
idpId: existingIdp.idp.idpId,
identifier: userIdentifier,
email,
name
})
.returning();
await trx.insert(idpUserOrg).values({
idpUserId: idpUserRes.idpUserId,
orgId,
roleId: userRole.roleId
});
userId = idpUserRes.idpUserId;
});
} else {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User not found and auto-provisioning is disabled"
)
);
}
}
const token = generateSessionToken();
const sess = await createIdpSession(token, userId);
const cookie = serializeIdpSessionCookie(
`p_idp_${orgId}.${idpId}`,
sess.idpSessionId,
req.protocol === "https",
new Date(sess.expiresAt)
);
res.setHeader("Set-Cookie", cookie);
return response<ValidateOidcUrlCallbackResponse>(res, {
data: {},
success: true,
error: false,
message: "OIDC callback validated successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}